Update TensorRT-LLM (#667)

* Update TensorRT-LLM

---------

Co-authored-by: 0xymoro <jerrymeng100@gmail.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Kaiyu Xie 2023-12-15 22:14:51 +08:00 committed by GitHub
parent f7eca56161
commit a75618df24
228 changed files with 1063764 additions and 711058 deletions

View File

@ -150,10 +150,13 @@ mkdir -p ./bloom/560M && git clone https://huggingface.co/bigscience/bloom-560m
```
***2. Build the engine***
```python
```bash
# Single GPU on BLOOM 560M
python build.py --model_dir ./bloom/560M/ \
python convert_checkpoint.py --model_dir ./bloom/560M/ \
--dtype float16 \
--output_dir ./bloom/560M/trt_ckpt/fp16/1-gpu/
# May need to add trtllm-build to PATH, export PATH=/usr/local/bin:$PATH
trtllm-build --checkpoint_dir ./bloom/560M/trt_ckpt/fp16/1-gpu/ \
--use_gemm_plugin float16 \
--use_gpt_attention_plugin float16 \
--output_dir ./bloom/560M/trt_engines/fp16/1-gpu/
@ -166,7 +169,7 @@ See the BLOOM [example](examples/bloom) for more details and options regarding t
The `../summarize.py` script can be used to perform the summarization of articles
from the CNN Daily dataset:
```python
```bash
python ../summarize.py --test_trt_llm \
--hf_model_dir ./bloom/560M/ \
--data_type fp16 \
@ -244,10 +247,12 @@ the models listed in the [examples](examples/.) folder.
The list of supported models is:
* [Baichuan](examples/baichuan)
* [BART](examples/enc_dec)
* [Bert](examples/bert)
* [Blip2](examples/blip2)
* [BLOOM](examples/bloom)
* [ChatGLM](examples/chatglm)
* [FairSeq NMT](examples/nmt)
* [Falcon](examples/falcon)
* [Flan-T5](examples/enc_dec)
* [GPT](examples/gpt)
@ -257,6 +262,7 @@ The list of supported models is:
* [InternLM](examples/internlm)
* [LLaMA](examples/llama)
* [LLaMA-v2](examples/llama)
* [mBART](examples/enc_dec)
* [Mistral](examples/llama#mistral-v01)
* [MPT](examples/mpt)
* [mT5](examples/enc_dec)
@ -269,7 +275,7 @@ The list of supported models is:
* [Whisper](examples/whisper)
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder
functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, etc. We
functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, NMT family, etc. We
unroll the exact model names in the list above to make specific models
easier to find.

View File

@ -96,7 +96,12 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
// copy inputs and wrap into shared_ptr
GenerationInput::TensorPtr inputIds;
std::vector<int32_t> inputsHost(batchSize * maxInputLength, padId);
std::vector<int32_t> inputsHost(batchSize * maxInputLength);
srand(time(0));
for (int i = 0; i < inputsHost.size(); i++)
{
inputsHost[i] = rand() % modelConfig.getVocabSizePadded(worldConfig.getSize());
}
if (inputPacked)
{
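The hunk above swaps the all-`padId` benchmark prompt for uniformly random token ids drawn from the padded vocabulary, which keeps the generated batches closer to real traffic. A self-contained sketch of the same initialization (the sizes and `vocabSizePadded` value below are stand-ins for what the benchmark reads from its model and world configs; `<random>` is used in place of `rand()`):

```cpp
#include <cstdint>
#include <random>
#include <vector>

int main()
{
    // Stand-in values; the benchmark obtains these from its model and world configs.
    int const batchSize = 8;
    int const maxInputLength = 128;
    std::int32_t const vocabSizePadded = 32000;

    // Fill the host-side prompt buffer with random token ids in [0, vocabSizePadded),
    // mirroring the change above but using <random> instead of rand().
    std::vector<std::int32_t> inputsHost(batchSize * maxInputLength);
    std::mt19937 gen{std::random_device{}()};
    std::uniform_int_distribution<std::int32_t> dist{0, vocabSizePadded - 1};
    for (auto& token : inputsHost)
    {
        token = dist(gen);
    }
    return 0;
}
```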

View File

@ -12,14 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, Optional
from dataclasses import asdict, dataclass
from typing import Optional
from pydantic import BaseModel, Extra
from tensorrt_llm.functional import PositionEmbeddingType
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
class BuildConfig(BaseModel, extra=Extra.allow):
@dataclass
class BuildConfig:
num_layers: int
num_heads: int
hidden_size: int
@ -41,20 +44,33 @@ class BuildConfig(BaseModel, extra=Extra.allow):
enable_qk_half_accum: bool = False
enable_context_fmha: bool = True
enable_multi_block_mode: bool = False
# The enum name of PositionEmbeddingType
# None means using the model family's default value defined in the ctor
position_embedding_type: Optional[PositionEmbeddingType] = None
position_embedding_type: str = None
# Only when position embedding is RoPE, this value makes sense, make
# default value to be None, not 0 or 1 to prevent misuse
# default value to be None, not the others to prevent misuse
rotary_pct: Optional[float] = None
bias: bool = True
quantization: Optional[str] = None
# use_custom_all_reduce gives better performance with NVLink
use_custom_all_reduce: bool = True
moe_num_experts: int = None
moe_top_k: int = None
moe_num_experts: int = 0
moe_top_k: int = 0
use_alibi: bool = None
remove_input_padding: bool = None
parallel_attention: bool = None
new_decoder_architecture: bool = None
class EncDecBuildConfig(BuildConfig, extra=Extra.allow):
@dataclass
class EncDecBuildConfig:
num_layers: int
num_heads: int
hidden_size: int
vocab_size: int
hidden_act: Optional[str]
n_positions: int
max_batch_size: int
num_decoder_layers: Optional[int] = None
head_size: Optional[int] = None
ffn_hidden_size: Optional[int] = None
@ -62,6 +78,8 @@ class EncDecBuildConfig(BuildConfig, extra=Extra.allow):
max_distance: Optional[int] = None
max_encoder_input_len: Optional[int] = None
max_decoder_input_len: Optional[int] = None
max_output_len: Optional[int] = None
builder_opt: Optional[int] = None
def __post_init__(self) -> None:
assert self.head_size is not None
@ -69,7 +87,8 @@ class EncDecBuildConfig(BuildConfig, extra=Extra.allow):
assert self.num_buckets is not None
class ModelConfig(BaseModel):
@dataclass
class ModelConfig:
name: str
family: str
benchmark_type: Literal["gpt", "bert", "enc_dec"]
@ -192,7 +211,7 @@ _allowed_configs = {
max_input_len=512,
max_output_len=200,
builder_opt=None,
position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
position_embedding_type='rope_gpt_neox',
rotary_pct=0.5,
bias=False,
)),
@ -285,6 +304,25 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
)),
"llama_7b_moe":
ModelConfig(name="llama_7b_moe",
family="llama",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
num_heads=32,
hidden_size=4096,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
moe_num_experts=4,
moe_top_k=1,
)),
"llama_13b":
ModelConfig(name="llama_13b",
family="llama",
@ -849,7 +887,7 @@ def get_allowed_models(benchmark_type=None):
def get_build_config(model_name):
if model_name in _allowed_configs:
return dict(_allowed_configs[model_name].build_config)
return asdict(_allowed_configs[model_name].build_config)
else:
raise KeyError(f'Unexpected model: {model_name}. Please add the model '
'to allowed_configs.py')
@ -861,3 +899,11 @@ def get_model_family(model_name):
else:
raise KeyError(f'Unexpected model: {model_name}. Please add the model '
'to allowed_configs.py')
def get_benchmark_type(model_name):
if model_name in _allowed_configs:
return _allowed_configs[model_name].benchmark_type
else:
raise KeyError(f'Unexpected model: {model_name}. Please add the model '
'to allowed_configs.py')

View File

@ -124,6 +124,7 @@ class BaseBenchmark(object):
report_dict["model_name"] = self.model_name
report_dict["world_size"] = self.world_size
report_dict["precision"] = self.dtype
report_dict["quantization"] = str(self.quant_mode)
report_dict["compute_cap"] = "sm" + get_compute_cap()
return report_dict

View File

@ -27,7 +27,7 @@ from base_benchmark import get_engine_name, serialize_engine
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.layers import MoeConfig, PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.models import PretrainedConfig, quantize_model
from tensorrt_llm.network import net_guard
@ -284,12 +284,12 @@ def build_gpt(args):
apply_query_key_layer_scaling,
position_embedding_type=PositionEmbeddingType.learned_absolute
if build_config['position_embedding_type'] is None else
build_config['position_embedding_type'],
PositionEmbeddingType[build_config['position_embedding_type']],
rotary_embedding_percentage=build_config['rotary_pct'],
quant_mode=quant_mode,
bias=build_config['bias'],
moe_layer_config=tensorrt_llm.moe_config.MoeLayerConfig(
build_config["moe_num_experts"], build_config["moe_top_k"]))
moe_config=MoeConfig(build_config["moe_num_experts"],
build_config["moe_top_k"]))
elif family == "opt":
config = {
'architecture': 'OPTForCausalLM',
@ -307,7 +307,29 @@ def build_gpt(args):
'use_parallel_embedding': False,
'share_embedding_table': False,
'embedding_sharding_dim': 0,
'do_layer_norm_before': build_config['do_layer_norm_before']
'do_layer_norm_before': build_config['do_layer_norm_before'],
'quantization': {
'use_smooth_quant':
quant_mode.has_act_and_weight_quant(),
'per_channel':
quant_mode.has_per_channel_scaling(),
'per_token':
quant_mode.has_per_token_dynamic_scaling(),
'per_group':
quant_mode.has_per_group_scaling(),
'group_size':
128,
'int8_kv_cache':
quant_mode.has_int8_kv_cache(),
'enable_fp8':
quant_mode.has_fp8_qdq(),
'fp8_kv_cache':
quant_mode.has_fp8_kv_cache(),
'use_weight_only':
quant_mode.is_weight_only(),
'weight_only_precision':
'int8' if quant_mode.is_int8_weight_only() else 'int4',
}
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.OPTForCausalLM(config)
@ -325,7 +347,9 @@ def build_gpt(args):
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
quant_mode=quant_mode,
use_fused_mlp=True)
use_fused_mlp=True,
moe_config=MoeConfig(build_config["moe_num_experts"],
build_config["moe_top_k"]))
elif family == "gptj":
tensorrt_llm_model = tensorrt_llm.models.GPTJForCausalLM(
num_layers=build_config['num_layers'],
@ -402,17 +426,47 @@ def build_gpt(args):
model_name="chatglm3_6b")
elif family == "bloom":
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
quant_mode=quant_mode,
use_parallel_embedding=(args.model == 'bloom_176b'))
config = {
'architecture': 'BloomForCausalLM',
'dtype': args.dtype,
'vocab_size': build_config['vocab_size'],
'hidden_size': build_config['hidden_size'],
'num_hidden_layers': build_config['num_layers'],
'num_attention_heads': build_config['num_heads'],
'hidden_act': build_config['hidden_act'],
'max_position_embeddings': build_config['n_positions'],
'mapping': {
'world_size': world_size,
'tp_size': world_size
},
'use_parallel_embedding': (args.model == 'bloom_176b'),
'share_embedding_table': False,
'embedding_sharding_dim': 0,
'quantization': {
'use_smooth_quant':
quant_mode.has_act_and_weight_quant(),
'per_channel':
quant_mode.has_per_channel_scaling(),
'per_token':
quant_mode.has_per_token_dynamic_scaling(),
'per_group':
quant_mode.has_per_group_scaling(),
'group_size':
128,
'int8_kv_cache':
quant_mode.has_int8_kv_cache(),
'enable_fp8':
quant_mode.has_fp8_qdq(),
'fp8_kv_cache':
quant_mode.has_fp8_kv_cache(),
'use_weight_only':
quant_mode.is_weight_only(),
'weight_only_precision':
'int8' if quant_mode.is_int8_weight_only() else 'int4',
}
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(config)
elif family == "falcon":
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(
num_layers=build_config['num_layers'],
@ -447,8 +501,9 @@ def build_gpt(args):
"zero": True,
"pre_quant_scale": False,
}
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
**quant_kwargs)
if family not in ['opt', 'bloom']:
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
**quant_kwargs)
# Module -> Network
network = builder.create_network()
@ -495,7 +550,10 @@ def build_gpt(args):
max_input_len,
max_output_len, True,
max_beam_width)
tensorrt_llm_model(*inputs)
if family in ['opt', 'bloom']:
tensorrt_llm_model(**inputs)
else:
tensorrt_llm_model(*inputs)
if args.mode == 'plugin':
tensorrt_llm.graph_rewriting.optimize(network)

View File

@ -137,7 +137,7 @@ endif()
message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES}")
enable_language(CUDA)
enable_language(C CXX CUDA)
find_package(CUDAToolkit REQUIRED)

View File

@ -254,7 +254,7 @@ public:
void startScheduling();
//! \brief Assign blocks for new sequence. Try to reuse blocks.
void addSequence(GenerationRequest& sequence, std::shared_ptr<LlmRequest> const& llmRequest);
void addSequence(GenerationRequest& sequence, SizeType inputLength, std::shared_ptr<LlmRequest> const& llmRequest);
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType inputLength, bool enableCyclicKvCache);
@ -449,6 +449,17 @@ public:
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig,
runtime::BufferManager const& bufferManager);
[[nodiscard]] SizeType getNumPrepopulatedTokens(SizeType batchSlotIdx, SizeType beamIdx) const
{
auto const& prepopulatedTokens = mSequences.at(batchSlotIdx)->getNumPrepopulatedTokens();
return prepopulatedTokens.size() > 0 ? prepopulatedTokens.at(beamIdx) : 0;
}
[[nodiscard]] bool isEnableBlockReuse() const
{
return mEnableBlockReuse;
}
private:
void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
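The new `getNumPrepopulatedTokens` accessor reports how many KV-cache tokens a sequence already has filled by block reuse, falling back to zero when no per-beam counts exist. A standalone sketch of that lookup (the `SequenceState` type below is a hypothetical stand-in for the manager's per-sequence bookkeeping, not the real class):

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the per-sequence bookkeeping kept by the KV-cache manager.
struct SequenceState
{
    // One entry per beam: how many KV-cache tokens were already filled by block reuse.
    std::vector<int> numPrepopulatedTokens;
};

// Mirrors the accessor added above: an empty vector simply yields zero.
int getNumPrepopulatedTokens(SequenceState const& seq, int beamIdx)
{
    auto const& prepopulated = seq.numPrepopulatedTokens;
    return prepopulated.size() > 0 ? prepopulated.at(beamIdx) : 0;
}

int main()
{
    SequenceState fresh{};          // nothing reused yet
    SequenceState reused{{16, 16}}; // two beams, 16 tokens reused for each
    assert(getNumPrepopulatedTokens(fresh, 0) == 0);
    assert(getNumPrepopulatedTokens(reused, 1) == 16);
    return 0;
}
```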

View File

@ -34,21 +34,18 @@ public:
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
bool useContextFMHAForGeneration = false,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt)
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt)
: kvCacheConfig{kvCacheConfig}
, maxNumSequences{maxNumSequences}
, enableTrtOverlap{enableTrtOverlap}
, useContextFMHAForGeneration(useContextFMHAForGeneration)
, userSpecifiedDeviceIds(userSpecifiedDeviceIds)
, deviceIds(deviceIds)
{
}
KvCacheConfig kvCacheConfig;
std::optional<SizeType> maxNumSequences;
bool enableTrtOverlap;
bool useContextFMHAForGeneration;
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds;
std::optional<std::vector<SizeType>> deviceIds;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -56,6 +56,8 @@ public:
, mUseCustomAllReduce(false)
, mMaxPromptEmbeddingTableSize(0)
, mMaxDraftLen(0)
, mUseContextFMHAForGeneration(false)
, mPagedContextFMHA(false)
{
}
@ -280,6 +282,26 @@ public:
return mMaxDraftLen + 1;
}
void constexpr setUseContextFMHAForGeneration(bool useContextFMHAForGeneration) noexcept
{
mUseContextFMHAForGeneration = useContextFMHAForGeneration;
}
[[nodiscard]] bool constexpr getContextFMHAForGeneration() const noexcept
{
return mUseContextFMHAForGeneration;
}
void constexpr setPagedContextFMHA(bool pagedContextFMHA) noexcept
{
mPagedContextFMHA = pagedContextFMHA;
}
[[nodiscard]] bool constexpr getPagedContextFMHA() const noexcept
{
return mPagedContextFMHA;
}
private:
SizeType mVocabSize;
SizeType mNbLayers;
@ -305,6 +327,9 @@ private:
SizeType mMaxPromptEmbeddingTableSize;
SizeType mMaxDraftLen;
bool mUseContextFMHAForGeneration;
bool mPagedContextFMHA;
};
} // namespace tensorrt_llm::runtime

View File

@ -30,14 +30,8 @@ public:
static SizeType constexpr kDefaultGpusPerNode = 8;
explicit WorldConfig(SizeType tensorParallelism = 1, SizeType pipelineParallelism = 1, SizeType rank = 0,
SizeType gpusPerNode = kDefaultGpusPerNode, std::vector<SizeType> deviceIds = {})
: mTensorParallelism{tensorParallelism}
, mPipelineParallelism{pipelineParallelism}
, mRank{rank}
, mGpusPerNode{gpusPerNode}
, mDeviceIds{deviceIds}
{
}
SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt);
[[nodiscard]] SizeType constexpr getSize() const noexcept
{
@ -74,13 +68,14 @@ public:
return mGpusPerNode;
}
[[nodiscard]] SizeType getGpusPerGroup() const noexcept
{
return static_cast<SizeType>(mDeviceIds.size());
}
[[nodiscard]] SizeType getDevice() const noexcept
{
if (mDeviceIds.size())
{
return mDeviceIds[mRank % mGpusPerNode];
}
return mRank % mGpusPerNode;
return mDeviceIds[mRank % getGpusPerGroup()];
}
[[nodiscard]] SizeType constexpr getPipelineParallelRank() const noexcept
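With this change the rank-to-GPU mapping goes through the user-supplied `deviceIds` list, indexed by rank modulo the group size. A small sketch of that selection logic, assuming the constructor still falls back to the plain `rank % gpusPerNode` mapping when no list is given (the function and variable names below are illustrative, not the actual `WorldConfig` API):

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Illustrative stand-in for the device selection after this change: an explicit
// device list is indexed by rank modulo the group size; without a list, the
// previous rank % gpusPerNode mapping is assumed to apply.
int selectDevice(int rank, int gpusPerNode, std::optional<std::vector<int>> const& deviceIds)
{
    if (deviceIds)
    {
        return (*deviceIds)[rank % static_cast<int>(deviceIds->size())];
    }
    return rank % gpusPerNode;
}

int main()
{
    // Two ranks pinned to GPUs 2 and 3 on an 8-GPU node.
    std::optional<std::vector<int>> const ids = std::vector<int>{2, 3};
    assert(selectDevice(0, 8, ids) == 2);
    assert(selectDevice(1, 8, ids) == 3);
    // No explicit list: fall back to rank % gpusPerNode.
    assert(selectDevice(9, 8, std::nullopt) == 1);
    return 0;
}
```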
@ -116,12 +111,12 @@ public:
static WorldConfig mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType> tensorParallelism = std::nullopt,
std::optional<SizeType> pipelineParallelism = std::nullopt,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt);
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt);
static WorldConfig mpi(SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType> tensorParallelism = std::nullopt,
std::optional<SizeType> pipelineParallelism = std::nullopt,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt);
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt);
private:
SizeType mTensorParallelism;

View File

@ -24,11 +24,11 @@ set(STATIC_TARGET
set(API_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include)
find_package(MPI REQUIRED)
message(STATUS "Using MPI_CXX_INCLUDE_DIRS: ${MPI_CXX_INCLUDE_DIRS}")
message(STATUS "Using MPI_CXX_LIBRARIES: ${MPI_CXX_LIBRARIES}")
message(STATUS "Using MPI_C_INCLUDE_DIRS: ${MPI_C_INCLUDE_DIRS}")
message(STATUS "Using MPI_C_LIBRARIES: ${MPI_C_LIBRARIES}")
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cutlass_extensions/include
${API_INCLUDE_DIR} ${MPI_INCLUDE_PATH})
${API_INCLUDE_DIR} ${MPI_C_INCLUDE_DIRS})
add_subdirectory(common)
add_subdirectory(kernels)
@ -114,7 +114,7 @@ set(TRTLLM_LINK_LIBS
${CUBLASLT_LIB}
${CUDNN_LIB}
${CMAKE_DL_LIBS}
${MPI_CXX_LIBRARIES}
${MPI_C_LIBRARIES}
${NCCL_LIB}
${TRT_LIB}
common_src

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d9f7d0f7dee2c48a424ff8873c2fd1298a27850f870657734641f2eb1190faf
size 1791038
oid sha256:51f905eed7ac6f5dbf12736519961100b8ac5f270cb96a79dd74c8f0a6837f24
size 1801452

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa79a0d563fc01a0cb2fe94dcb626ff4e5b736284d9244313cbe7aa0261dd48e
size 1806500
oid sha256:17ea5ea3b9cf666091a2997da2c7980e36a8b59b320668c5453886fb766c2db5
size 1819266

View File

@ -1,3 +1,3 @@
d9723ab671c9fc3889cc624a58def81a libtensorrt_llm_batch_manager_static.a
4b6773c990e8a59f1c716d88505b84a2 libtensorrt_llm_batch_manager_static.pre_cxx11.a
9a136bb59c51bbae09221c1667e23529ed05c752 commit
d7a189cfdfbc3ebe45f07faf7be61434 libtensorrt_llm_batch_manager_static.a
f18c7b1d843bbcc48a63fb0eedc3ade5 libtensorrt_llm_batch_manager_static.pre_cxx11.a
5218d255437b68c905d8be35a959405c20e6254c commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a7b872fe6ee63a4342c3cd17b3557d74c72e537dbf0d4ddf132a2c40e000e57
size 1709462
oid sha256:1d2fd4c684ea3de95fb1070e28e2938760d024fec0bd5d710585c3f48c659b8f
size 1721606

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c83f7c0e4fc22b32df669ada2b99b88f0f7faac935a251fe7a20030e2b364cc8
size 1705432
oid sha256:ec5f659d47742f96f36385cfac33a5a5d4159acf0ee029d9c01c56eb5afb9afc
size 1715582

View File

@ -1,2 +1,2 @@
583141c3003a08acebc7054d024bee89 libtensorrt_llm_batch_manager_static.a
03e5360f9b8074b8273500898581212f libtensorrt_llm_batch_manager_static.pre_cxx11.a
723277a6fc05a0589b96a66893a8518b libtensorrt_llm_batch_manager_static.a
40defd0eb709bea4c0b8e363fd264142 libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -48,7 +48,7 @@ public:
template <typename T>
[[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, const bool setZero = true)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
// TODO martinma: why do we need this size extension?
auto const sizeAligned = ((sizeBytes + 31) / 32) * 32; // make the buffer align with 32 bytes
if (contains(ptr))

View File

@ -48,7 +48,7 @@ ReallocType CudaAllocator::reallocType(void const* ptr, size_t size) const
void* CudaAllocator::malloc(std::size_t size, bool const setZero)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
auto bufferPtr = mBufferManager.gpu(size);
if (setZero)
{
@ -62,7 +62,7 @@ void* CudaAllocator::malloc(std::size_t size, bool const setZero)
void CudaAllocator::free(void** ptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mPointerMapping.erase(*ptr);
*ptr = nullptr;
}

View File

@ -297,9 +297,12 @@ inline int getMaxSharedMemoryPerBlockOptin()
return max_shared_memory_per_block;
}
inline int divUp(int a, int n)
template <typename T1, typename T2>
inline size_t divUp(const T1& a, const T2& n)
{
return (a + n - 1) / n;
size_t tmp_a = static_cast<size_t>(a);
size_t tmp_n = static_cast<size_t>(n);
return (tmp_a + tmp_n - 1) / tmp_n;
}
template <typename T, typename U, typename = std::enable_if_t<std::is_integral<T>::value>,
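`divUp` is now a template that widens both operands to `size_t` before rounding up, so callers that mix `int` with large sizes no longer truncate. A minimal standalone check of the same helper (a 64-bit `size_t` is assumed for the second assertion):

```cpp
#include <cassert>
#include <cstddef>

// Ceiling division that widens both operands to std::size_t first, mirroring the
// templated helper introduced above.
template <typename T1, typename T2>
std::size_t divUp(T1 const& a, T2 const& n)
{
    auto const tmpA = static_cast<std::size_t>(a);
    auto const tmpN = static_cast<std::size_t>(n);
    return (tmpA + tmpN - 1) / tmpN;
}

int main()
{
    assert(divUp(10, 3) == 4);
    // A value larger than INT_MAX divides cleanly instead of truncating.
    assert(divUp(std::size_t{1} << 33, 4096) == (std::size_t{1} << 21));
    return 0;
}
```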

View File

@ -0,0 +1,65 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <cstdlib>
namespace tensorrt_llm::common
{
// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug()
{
static bool init = false;
static bool forceMmhaMaxSeqLenTile = false;
if (!init)
{
init = true;
const char* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
if (enable_mmha_debug_var)
{
if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
{
forceMmhaMaxSeqLenTile = true;
}
}
}
return forceMmhaMaxSeqLenTile;
}
int getEnvMmhaBlocksPerSequence()
{
static bool init = false;
static int mmhaBlocksPerSequence = 0;
if (!init)
{
init = true;
const char* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
if (mmhaBlocksPerSequenceEnv)
{
mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
if (mmhaBlocksPerSequence <= 0)
{
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!");
}
}
}
return mmhaBlocksPerSequence;
}
} // namespace tensorrt_llm::common
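These helpers read their tuning knobs once and cache the result, so the environment variables must be set before the first multi-block MMHA launch. A self-contained sketch of how a process could exercise the same variables (the parsing mirrors the code above; `setenv` is POSIX and is used here only to keep the example standalone, since the variables would normally be exported in the shell):

```cpp
#include <cstdio>
#include <cstdlib>

int main()
{
    // Export before the first MMHA launch; the helpers cache the values on first use.
    setenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG", "1", /*overwrite=*/1);
    setenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE", "8", /*overwrite=*/1);

    // Parsed the same way as the helpers above.
    char const* multiBlock = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
    bool const forceMultiBlock = multiBlock != nullptr && multiBlock[0] == '1' && multiBlock[1] == '\0';

    char const* blocksEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
    int const blocksPerSequence = blocksEnv != nullptr ? std::atoi(blocksEnv) : 0;

    std::printf("force multi-block: %d, blocks per sequence: %d\n", forceMultiBlock, blocksPerSequence);
    return 0;
}
```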

View File

@ -0,0 +1,28 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace tensorrt_llm::common
{
// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug();
int getEnvMmhaBlocksPerSequence();
} // namespace tensorrt_llm::common

View File

@ -233,7 +233,7 @@ public:
template <typename T>
inline T getVal(size_t index) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(where == MEMORY_CPU);
TLLM_CHECK(data != nullptr);
TLLM_CHECK_WITH_INFO(index < size(), "index is larger than buffer size");
@ -249,7 +249,7 @@ public:
template <typename T>
inline T getVal() const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (getTensorType<T>() != type)
{
TLLM_LOG_DEBUG("getVal with type %s, but data type is: %s", getNumpyTypeDesc(getTensorType<T>()).c_str(),
@ -261,7 +261,7 @@ public:
template <typename T>
inline T* getPtr() const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (getTensorType<T>() != type)
{
TLLM_LOG_DEBUG("getPtr with type %s, but data type is: %s", getNumpyTypeDesc(getTensorType<T>()).c_str(),
@ -272,7 +272,7 @@ public:
inline void* getPtrWithOffset(size_t offset) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (data == nullptr)
{
return (void*) data;
@ -287,7 +287,7 @@ public:
template <typename T>
inline T* getPtrWithOffset(size_t offset) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (getTensorType<T>() != type)
{
TLLM_LOG_DEBUG("getVal with type %s, but data type is: %s", getNumpyTypeDesc(getTensorType<T>()).c_str(),
@ -431,7 +431,7 @@ public:
inline bool contains(const std::string& key) const
{
TLLM_LOG_DEBUG("%s for key: %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key: %s", __PRETTY_FUNCTION__, key.c_str());
return tensor_map_.find(key) != tensor_map_.end();
}
@ -464,7 +464,7 @@ public:
inline Tensor& at(const std::string& key)
{
TLLM_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_CHECK_WITH_INFO(contains(key),
fmtstr(
"Cannot find a tensor of name %s in the tensor map (keys: %s)", key.c_str(), vec2str(keys()).c_str()));
@ -489,7 +489,7 @@ public:
inline Tensor& at(const std::string& key, Tensor& default_tensor)
{
TLLM_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key))
{
return tensor_map_.at(key);
@ -499,7 +499,7 @@ public:
inline Tensor at(const std::string& key, Tensor& default_tensor) const
{
TLLM_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key))
{
return tensor_map_.at(key);
@ -509,7 +509,7 @@ public:
inline Tensor& at(const std::string& key, Tensor&& default_tensor)
{
TLLM_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key))
{
return tensor_map_.at(key);

View File

@ -50,7 +50,7 @@ TllmException::TllmException(char const* file, std::size_t line, const std::stri
}
#endif
TllmException::~TllmException() = default;
TllmException::~TllmException() noexcept = default;
std::string TllmException::getTrace() const
{

View File

@ -0,0 +1,82 @@
/*
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
#include <cstdint>
namespace tensorrt_llm::common
{
std::uintptr_t constexpr kCudaMemAlign = 128;
namespace
{
int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
uintptr_t addr = (uintptr_t) ptr;
if (addr % to)
{
addr += to - addr % to;
}
return (int8_t*) addr;
}
int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, const uintptr_t alignment)
{
uintptr_t addr = (uintptr_t) ptr;
addr += previousWorkspaceSize;
return alignPtr((int8_t*) addr, alignment);
}
int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign);
}
int8_t* nextWorkspacePtr(
int8_t* const base, uintptr_t& offset, const uintptr_t size, const uintptr_t alignment = kCudaMemAlign)
{
uintptr_t curr_offset = offset;
uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment;
int8_t* newptr = size == 0 ? nullptr : base + curr_offset;
offset = next_offset;
return newptr;
}
int8_t* nextWorkspacePtrWithAlignment(
int8_t* ptr, uintptr_t previousWorkspaceSize, const uintptr_t alignment = kCudaMemAlign)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
}
size_t calculateTotalWorkspaceSize(size_t* workspaces, int count, const uintptr_t alignment = kCudaMemAlign)
{
size_t total = 0;
for (int i = 0; i < count; i++)
{
total += workspaces[i];
if (workspaces[i] % alignment)
{
total += alignment - (workspaces[i] % alignment);
}
}
return total;
}
} // namespace
}; // namespace tensorrt_llm::common
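The new header collects the workspace helpers used to carve one large allocation into 128-byte-aligned sub-buffers. A standalone sketch of that pattern applied to a host buffer (it reproduces `alignPtr` and `calculateTotalWorkspaceSize` locally so the example compiles on its own):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// 128-byte alignment, matching kCudaMemAlign above.
constexpr std::uintptr_t kAlign = 128;

// Round a pointer up to the next aligned address (same logic as alignPtr above).
std::int8_t* alignPtr(std::int8_t* ptr, std::uintptr_t to)
{
    auto addr = reinterpret_cast<std::uintptr_t>(ptr);
    if (addr % to)
    {
        addr += to - addr % to;
    }
    return reinterpret_cast<std::int8_t*>(addr);
}

// Sum of the sub-buffer sizes, each padded up to the alignment boundary.
std::size_t calculateTotalWorkspaceSize(std::size_t const* workspaces, int count, std::uintptr_t alignment = kAlign)
{
    std::size_t total = 0;
    for (int i = 0; i < count; i++)
    {
        total += workspaces[i];
        if (workspaces[i] % alignment)
        {
            total += alignment - (workspaces[i] % alignment);
        }
    }
    return total;
}

int main()
{
    // Three sub-buffers of awkward sizes; each is padded to a 128-byte boundary.
    std::size_t const sizes[] = {100, 256, 33};
    std::size_t const total = calculateTotalWorkspaceSize(sizes, 3);
    assert(total == 512);

    // Carve a single host allocation the way a plugin carves its device workspace.
    std::vector<std::int8_t> buffer(total + kAlign);
    std::int8_t* base = alignPtr(buffer.data(), kAlign);
    std::int8_t* ws0 = base;
    std::int8_t* ws1 = alignPtr(ws0 + sizes[0], kAlign);
    std::int8_t* ws2 = alignPtr(ws1 + sizes[1], kAlign);
    assert(ws2 + sizes[2] <= base + total);
    return 0;
}
```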

View File

@ -44,7 +44,7 @@ struct MixedGemmArchTraits<float, float, arch>
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;

View File

@ -52,7 +52,7 @@ template <typename TypeB>
struct LayoutDetailsB<TypeB, arch::Sm70>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 8;
using Operator = cutlass::arch::OpMultiplyAdd;
};
@ -63,7 +63,7 @@ template <typename Arch>
struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
@ -72,7 +72,7 @@ template <typename Arch>
struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};

View File

@ -507,6 +507,10 @@ public:
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif

View File

@ -90,22 +90,24 @@ public:
pagedKVXmmaKernel = getPagedKVXMMAKernelsV2(mDataType, sm);
xmmaKernel = getXMMAKernelsV2(mDataType, sm);
params.clear();
pagedKVParams.clear();
mParams.clear();
mPagedKVParams.clear();
// get device attributes
int device_id;
cudaGetDevice(&device_id);
cudaDeviceGetAttribute(&launch_params.multi_processor_count, cudaDevAttrMultiProcessorCount, device_id);
cudaDeviceGetAttribute(&launch_params.device_l2_cache_size, cudaDevAttrL2CacheSize, device_id);
cudaDeviceGetAttribute(&mLaunchParams.multi_processor_count, cudaDevAttrMultiProcessorCount, device_id);
cudaDeviceGetAttribute(&mLaunchParams.device_l2_cache_size, cudaDevAttrL2CacheSize, device_id);
}
~mhaImpl() {}
// Support packed QKV.
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
// Shared setup function.
template <typename Params>
void setup_params(Params& params, const int b, const int s_q, const int s_kv, const int sliding_window_size,
const int total_seqlen, const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank)
{
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
// Note that we apply scales and bias in the order of
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
@ -114,79 +116,108 @@ public:
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
Data_type scale_type = launch_params.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
set_alpha(params.scale_bmm1, scale_bmm1, scale_type);
Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
// Use specialized ws kernels on Hopper for cases without alibi.
if (mLaunchParams.useKernelWithoutAlibi)
{
// The kernel adopts the log2f optimization.
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
set_alpha(params.scale_bmm1, scale_bmm1 * float(kLog2e), DATA_TYPE_FP32);
}
else
{
set_alpha(params.scale_bmm1, scale_bmm1, scale_type);
}
set_alpha(params.scale_softmax, scale_softmax, scale_type);
set_alpha(params.scale_bmm2, scale_bmm2, scale_type);
params.b = b;
params.h = mNumHeads;
params.s = s;
params.s = s_q;
params.d = mHeadSize;
params.sliding_window_size = sliding_window_size;
params.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
// Total sequence length needed by TMA descriptor
// it should be actual total seq length if non-padded input is given.
mTotalSeqLen = total_seqlen;
params.qkv_stride_in_bytes = (mNumHeads + 2 * params.h_kv) * mHeadSize * sizeof(half);
params.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
const bool isSm90 = (sm == kSM_90);
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
const bool isSm80 = (sm == kSM_80);
if (isSm90 && params.d <= 64 && params.s <= 256)
{
launch_params.flash_attention = false;
// get the max sequence length for non-flash-attention kernels
launch_params.kernel_s = getSFromMaxSeqLen(params.s);
}
else
{ // always use flash attention kernels for Ampere/Ada
launch_params.flash_attention = true;
// flash attention kernels use s = 0 (any seq length supported)
launch_params.kernel_s = 0;
launch_params.force_unroll = true;
// enable tiled kernels on Ampere/Ada
if (launch_params.flash_attention && params.s <= 64)
{
// flash attention tiled kernels allow a larger free dim tile size (M, N) with flexibility
// in the unroll dimension tile size (K). For short sequence lengths (s<=128), tiled kernels
// can suffer from tile quantization loss, so use non-tiled flash attention instead
launch_params.granular_tiling = false;
}
else if (isSm8x && params.d < 256)
{
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
launch_params.granular_tiling = false;
}
else if (isSm80 || isSm8x)
{
// otherwise, choose tiled kernel for Ampere/Ada
launch_params.granular_tiling = true;
}
}
// when flash attention is enabled on Hopper, we need to set the tma descriptors
if (isSm90 && launch_params.flash_attention)
{
launch_params.warp_specialization = true;
launch_params.use_tma = true;
}
// alibi.
if (has_alibi)
{
params.has_alibi = true;
params.alibi_params = AlibiParams(mNumHeads, s, tp_size, tp_rank, scale_after_alibi);
params.alibi_params = AlibiParams(mNumHeads, s_kv, tp_size, tp_rank, scale_after_alibi);
}
}
// Support packed QKV.
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
{
// Determine launch parameters.
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
mLaunchParams.set_default_kernel_selection_params();
const bool isSm90 = (sm == kSM_90);
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
const bool isSm80 = (sm == kSM_80);
if (isSm90 && mHeadSize <= 64 && s <= 256)
{
mLaunchParams.flash_attention = false;
// get the max sequence length for non-flash-attention kernels
mLaunchParams.kernel_s = getSFromMaxSeqLen(s);
}
else
{ // always use flash attention kernels for Ampere/Ada
mLaunchParams.flash_attention = true;
// flash attention kernels use s = 0 (any seq length supported)
mLaunchParams.kernel_s = 0;
mLaunchParams.force_unroll = true;
// enable tiled kernels on Ampere/Ada
if (mLaunchParams.flash_attention && s <= 64)
{
// flash attention tiled kernels allow a larger free dim tile size (M, N) with flexibility
// in the unroll dimension tile size (K). For short sequence lengths (s<=128), tiled kernels
// can suffer from tile quantization loss, so use non-tiled flash attention instead
mLaunchParams.granular_tiling = false;
}
else if (isSm8x && mHeadSize < 256)
{
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
mLaunchParams.granular_tiling = false;
}
else if (isSm80 || isSm8x)
{
// otherwise, choose tiled kernel for Ampere/Ada
mLaunchParams.granular_tiling = true;
}
}
// when flash attention is enabled on Hopper, we need to set the tma descriptors
if (isSm90 && mLaunchParams.flash_attention)
{
mLaunchParams.warp_specialization = true;
mLaunchParams.use_tma = true;
}
// Use specialized ws kernels on Hopper for cases without alibi.
if (mLaunchParams.warp_specialization && !has_alibi)
{
// Use specialized ws kernels for cases without alibi.
mLaunchParams.useKernelWithoutAlibi = true;
}
// Sliding_window_causal mask.
if (s > sliding_window_size && launch_params.attention_mask_type == ContextAttentionMaskType::CAUSAL)
if (s > sliding_window_size && mLaunchParams.attention_mask_type == ContextAttentionMaskType::CAUSAL)
{
params.sliding_window_size = sliding_window_size;
launch_params.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
mLaunchParams.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
}
// Set kernel parameters.
setup_params(mParams, b, s, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
mParams.qkv_stride_in_bytes = (mNumHeads + 2 * mParams.h_kv) * mHeadSize * sizeof(half);
}
// Support paged_kv_cache and chunked_attention.
@ -194,34 +225,13 @@ public:
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
{
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
// Note that we apply scales and bias in the order of
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
const float scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
const float scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
Data_type scale_type = launch_params.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
set_alpha(pagedKVParams.scale_bmm1, scale_bmm1, scale_type);
set_alpha(pagedKVParams.scale_softmax, scale_softmax, scale_type);
set_alpha(pagedKVParams.scale_bmm2, scale_bmm2, scale_type);
pagedKVParams.b = b;
pagedKVParams.h = mNumHeads;
pagedKVParams.s = s_q;
pagedKVParams.d = mHeadSize;
// Total sequence length needed by TMA descriptor
// it should be actual total seq length if non-padded input is given.
mTotalSeqLen = total_seqlen;
// Determine launch parameters.
mLaunchParams.set_default_kernel_selection_params();
TLLM_CHECK_WITH_INFO(tokens_per_kv_block >= 128, "FMHA with paged kv cache needs tokens_per_block >= 128 !");
// Needed by TMA descriptors.
launch_params.blocks_per_context_sequence = blocks_per_context_sequence;
pagedKVParams.q_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
pagedKVParams.kv_stride_in_bytes = tokens_per_kv_block * mHeadSize * sizeof(half);
pagedKVParams.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
mLaunchParams.blocks_per_context_sequence = blocks_per_context_sequence;
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
const bool isSm90 = (sm == kSM_90);
@ -229,53 +239,57 @@ public:
const bool isSm80 = (sm == kSM_80);
// always use flash attention kernels.
launch_params.flash_attention = true;
mLaunchParams.flash_attention = true;
// flash attention kernels use s = 0 (any seq length supported)
launch_params.kernel_s = 0;
launch_params.kernel_kv_s = s_kv;
launch_params.force_unroll = true;
mLaunchParams.kernel_s = 0;
mLaunchParams.kernel_kv_s = s_kv;
mLaunchParams.force_unroll = true;
// enable warp-specialization kernels when s > 512.
if (isSm90 && s_kv > 512)
{
launch_params.warp_specialization = true;
launch_params.use_tma = true;
mLaunchParams.warp_specialization = true;
mLaunchParams.use_tma = true;
}
else
{
// enable tiled kernels on Ampere/Ada
if (launch_params.flash_attention && s_kv <= 64)
if (mLaunchParams.flash_attention && s_kv <= 64)
{
// flash attention tiled kernels allow a larger free dim tile size (M, N) with flexibility
// in the unroll dimension tile size (K). For short sequence lengths (s<=128), tiled kernels
// can suffer from tile quantization loss, so use non-tiled flash attention instead
launch_params.granular_tiling = false;
mLaunchParams.granular_tiling = false;
}
else if (isSm8x && params.d < 256)
else if (isSm8x && mParams.d < 256)
{
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
launch_params.granular_tiling = false;
mLaunchParams.granular_tiling = false;
}
else if (isSm90 || isSm80 || isSm8x)
{
// otherwise, choose tiled kernel for Ampere/Ada
launch_params.granular_tiling = true;
mLaunchParams.granular_tiling = true;
}
}
// alibi.
if (has_alibi)
// Use specialized ws kernels on Hopper for cases without alibi.
if (mLaunchParams.warp_specialization && !has_alibi)
{
pagedKVParams.has_alibi = true;
pagedKVParams.alibi_params = AlibiParams(mNumHeads, s_kv, tp_size, tp_rank, scale_after_alibi);
// Use specialized ws kernels for cases without alibi.
mLaunchParams.useKernelWithoutAlibi = true;
}
// Sliding_window_causal mask.
if (s_kv > sliding_window_size && launch_params.attention_mask_type == ContextAttentionMaskType::CAUSAL)
if (s_kv > sliding_window_size && mLaunchParams.attention_mask_type == ContextAttentionMaskType::CAUSAL)
{
pagedKVParams.sliding_window_size = sliding_window_size;
launch_params.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
mLaunchParams.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
}
setup_params(
mPagedKVParams, b, s_q, s_kv, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
mPagedKVParams.q_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
mPagedKVParams.kv_stride_in_bytes = tokens_per_kv_block * mHeadSize * sizeof(half);
}
// NOTE: assume that heads_interleaved = false (b, s, 3, h, d), and sequences are padded/non-padded
@ -283,7 +297,7 @@ public:
void set_tma_descriptors()
{
// split D into multiple groups in order to match the TMA swizzle mode (128B)
const uint32_t d_in_bytes = params.d * sizeof(uint16_t);
const uint32_t d_in_bytes = mParams.d * sizeof(uint16_t);
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
// separate q, k, and v tma descriptors
@ -291,18 +305,18 @@ public:
// tensor size
uint32_t tensor_size_qkv[4];
if (params.h_kv < params.h)
if (mParams.h_kv < mParams.h)
{
// if multi-query or grouped-query
tensor_size_qkv[2] = 1;
tensor_size_qkv[1] = (params.h + 2 * params.h_kv);
tensor_size_qkv[0] = params.d; // params.d;
tensor_size_qkv[1] = (mParams.h + 2 * mParams.h_kv);
tensor_size_qkv[0] = mParams.d; // mParams.d;
}
else
{
tensor_size_qkv[2] = 3;
tensor_size_qkv[1] = params.h;
tensor_size_qkv[0] = params.d; // params.d;
tensor_size_qkv[1] = mParams.h;
tensor_size_qkv[0] = mParams.d; // mParams.d;
}
// box size for k and v
@ -310,7 +324,7 @@ public:
// Update this on device?
box_size[2] = 1;
box_size[1] = 1;
box_size[0] = params.d / d_groups;
box_size[0] = mParams.d / d_groups;
// stride size in bytes. Assumes least significant dim is 1 (?)
uint64_t tensor_stride_qkv[3];
@ -328,7 +342,7 @@ public:
uint32_t fp32_to_tf32 = 0;
// gmma descriptor mode
const uint32_t d_bytes_per_group = (params.d * sizeof(uint16_t)) / d_groups;
const uint32_t d_bytes_per_group = (mParams.d * sizeof(uint16_t)) / d_groups;
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
? cudaTmaDescSwizzle::SWIZZLE_128B
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
@ -336,7 +350,7 @@ public:
uint32_t q_step = 0, kv_step = 0;
for (unsigned int i = 0u; i < sizeof(sTmaMetaInfo) / sizeof(sTmaMetaInfo[0]); ++i)
{
if (sTmaMetaInfo[i].mD == params.d)
if (sTmaMetaInfo[i].mD == mParams.d)
{
q_step = sTmaMetaInfo[i].mQStep;
kv_step = sTmaMetaInfo[i].mKVStep;
@ -346,7 +360,7 @@ public:
// QKV [TOTAL, 3, h, d]
// NOTE: we may need to use actual seqlen to set oob_value
const char* qkv_ptr = reinterpret_cast<const char*>(params.qkv_ptr);
const char* qkv_ptr = reinterpret_cast<const char*>(mParams.qkv_ptr);
tensor_size_qkv[3] = mTotalSeqLen;
// Q: STEP_Q
@ -354,18 +368,18 @@ public:
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&params.tma_desc_q);
&mParams.tma_desc_q);
// K/V: STEP_KV
box_size[3] = kv_step;
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&params.tma_desc_k);
&mParams.tma_desc_k);
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&params.tma_desc_v);
&mParams.tma_desc_v);
}
// Q are contiguous in the shape of [B, S, H, D]
@ -375,13 +389,13 @@ public:
void set_paged_kv_tma_descriptors(cudaStream_t stream)
{
// split D into multiple groups in order to match the TMA swizzle mode (128B)
const uint32_t d_in_bytes = pagedKVParams.d * sizeof(uint16_t);
const uint32_t d_in_bytes = mPagedKVParams.d * sizeof(uint16_t);
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
uint32_t q_step = 0, kv_step = 0;
for (unsigned int i = 0u; i < sizeof(sTmaPagedKVMetaInfo) / sizeof(sTmaPagedKVMetaInfo[0]); ++i)
{
if (sTmaPagedKVMetaInfo[i].mD == pagedKVParams.d)
if (sTmaPagedKVMetaInfo[i].mD == mPagedKVParams.d)
{
q_step = sTmaPagedKVMetaInfo[i].mQStep;
kv_step = sTmaPagedKVMetaInfo[i].mKVStep;
@ -392,21 +406,21 @@ public:
// Separate q, and paged kv tma descriptors.
Multiple_tma_descriptor<4> q_tma_descriptor;
Multiple_tma_descriptor<4> paged_kv_tma_descriptor(
pagedKVParams.b * 2 * launch_params.blocks_per_context_sequence);
mPagedKVParams.b * 2 * mLaunchParams.blocks_per_context_sequence);
// Contiguous Q
// query tensor size [B x S, 1, H, D]
uint32_t tensor_size_q[4];
tensor_size_q[3] = mTotalSeqLen;
tensor_size_q[2] = 1;
tensor_size_q[1] = pagedKVParams.h;
tensor_size_q[0] = pagedKVParams.d;
tensor_size_q[1] = mPagedKVParams.h;
tensor_size_q[0] = mPagedKVParams.d;
// box size for k and v
uint32_t box_size_q[4];
box_size_q[3] = q_step;
box_size_q[2] = 1;
box_size_q[1] = 1;
box_size_q[0] = pagedKVParams.d / d_groups;
box_size_q[0] = mPagedKVParams.d / d_groups;
// stride size in bytes.
uint64_t tensor_stride_q[3];
@ -424,34 +438,34 @@ public:
uint32_t fp32_to_tf32 = 0;
// gmma descriptor mode
const uint32_t d_bytes_per_group = (pagedKVParams.d * sizeof(uint16_t)) / d_groups;
const uint32_t d_bytes_per_group = (mPagedKVParams.d * sizeof(uint16_t)) / d_groups;
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
? cudaTmaDescSwizzle::SWIZZLE_128B
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
// Q ptr.
const char* q_ptr = reinterpret_cast<const char*>(pagedKVParams.q_ptr);
const char* q_ptr = reinterpret_cast<const char*>(mPagedKVParams.q_ptr);
// Q: STEP_Q.
q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_q, tensor_stride_q, traversal_stride, box_size_q, oob_fill, fp32_to_tf32,
&pagedKVParams.tma_desc_q);
&mPagedKVParams.tma_desc_q);
// Paged KV
// Per batch tensor size.
uint32_t tensor_size_kv[4];
tensor_size_kv[3] = 1;
tensor_size_kv[2] = pagedKVParams.h_kv;
tensor_size_kv[1] = pagedKVParams.paged_kv_cache.mTokensPerBlock;
tensor_size_kv[0] = pagedKVParams.d;
tensor_size_kv[2] = mPagedKVParams.h_kv;
tensor_size_kv[1] = mPagedKVParams.paged_kv_cache.mTokensPerBlock;
tensor_size_kv[0] = mPagedKVParams.d;
// Box size for k and v.
uint32_t box_size_kv[4];
box_size_kv[3] = 1;
box_size_kv[2] = 1;
box_size_kv[1] = kv_step;
box_size_kv[0] = pagedKVParams.d / d_groups;
box_size_kv[0] = mPagedKVParams.d / d_groups;
// Stride size in bytes.
uint64_t tensor_stride_kv[3];
@ -461,40 +475,40 @@ public:
// 2 stands for k, and v blocks.
// We only need to prepare as many tma descriptors as the number of paged kv blocks for context.
for (int block_idx = 0; block_idx < pagedKVParams.b * 2 * launch_params.blocks_per_context_sequence;
for (int block_idx = 0; block_idx < mPagedKVParams.b * 2 * mLaunchParams.blocks_per_context_sequence;
block_idx++)
{
int block_ptr_idx = int(block_idx / launch_params.blocks_per_context_sequence)
* pagedKVParams.paged_kv_cache.mMaxBlocksPerSeq
+ (block_idx % launch_params.blocks_per_context_sequence);
int block_ptr_idx = int(block_idx / mLaunchParams.blocks_per_context_sequence)
* mPagedKVParams.paged_kv_cache.mMaxBlocksPerSeq
+ (block_idx % mLaunchParams.blocks_per_context_sequence);
paged_kv_tma_descriptor.set_tma_desctriptor(
reinterpret_cast<char*>(launch_params.paged_kv_block_ptrs[block_ptr_idx]), cudaTmaDescFormat::F16_RN,
reinterpret_cast<char*>(mLaunchParams.paged_kv_block_ptrs[block_ptr_idx]), cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_kv, tensor_stride_kv, traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, block_idx);
}
// set mMaxBlocksPerSeq to the number of blocks needed for context.
pagedKVParams.paged_kv_cache.mMaxBlocksPerSeq = launch_params.blocks_per_context_sequence;
mPagedKVParams.paged_kv_cache.mMaxBlocksPerSeq = mLaunchParams.blocks_per_context_sequence;
paged_kv_tma_descriptor.copy_to_device(pagedKVParams.tma_desc_paged_kv, stream);
paged_kv_tma_descriptor.copy_to_device(mPagedKVParams.tma_desc_paged_kv, stream);
}
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads)
{
// BF16 FMHA only accumulates on FP32
launch_params.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
launch_params.attention_mask_type
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
mLaunchParams.attention_mask_type
= causal_mask ? ContextAttentionMaskType::CAUSAL : ContextAttentionMaskType::PADDING;
// Paged KV Cache.
pagedKVParams.h_kv = num_kv_heads;
mPagedKVParams.h_kv = num_kv_heads;
TLLM_CHECK_WITH_INFO(mNumHeads % num_kv_heads == 0, "number of Query heads should be multiple of KV heads !");
pagedKVParams.h_q_per_kv = mNumHeads / num_kv_heads;
pagedKVParams.is_s_padded = is_s_padded;
mPagedKVParams.h_q_per_kv = mNumHeads / num_kv_heads;
mPagedKVParams.is_s_padded = is_s_padded;
// Contiguous Cache.
params.h_kv = num_kv_heads;
params.is_s_padded = is_s_padded;
mParams.h_kv = num_kv_heads;
mParams.is_s_padded = is_s_padded;
}
bool fmha_supported()
@ -504,39 +518,39 @@ public:
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
{
params.qkv_ptr = qkvPtr;
params.o_ptr = outputPtr;
params.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr);
mParams.qkv_ptr = qkvPtr;
mParams.o_ptr = outputPtr;
mParams.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr);
if (sm == kSM_90 && launch_params.use_tma)
if (sm == kSM_90 && mLaunchParams.use_tma)
{
// memcpy H2D has been removed by applying grid_constant tma descriptors.
set_tma_descriptors();
}
xmmaKernel->run(params, launch_params, stream);
xmmaKernel->run(mParams, mLaunchParams, stream);
}
void run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
cudaStream_t stream)
{
pagedKVParams.q_ptr = qPtr;
pagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
pagedKVParams.paged_kv_cache = pagedKVCache;
pagedKVParams.o_ptr = outputPtr;
pagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
pagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
mPagedKVParams.q_ptr = qPtr;
mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
mPagedKVParams.paged_kv_cache = pagedKVCache;
mPagedKVParams.o_ptr = outputPtr;
mPagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
mPagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
// paged kv block device ptrs on host (used by tma descriptors).
launch_params.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost);
mLaunchParams.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost);
if (sm == kSM_90 && launch_params.use_tma)
if (sm == kSM_90 && mLaunchParams.use_tma)
{
// memcpy H2D is needed as we use multiple tma descriptors in device memory.
set_paged_kv_tma_descriptors(stream);
}
pagedKVXmmaKernel->run(pagedKVParams, launch_params, stream);
pagedKVXmmaKernel->run(mPagedKVParams, mLaunchParams, stream);
}
bool isValid(int s) const
@ -578,9 +592,9 @@ public:
}
private:
Fused_multihead_attention_params_v2 params;
Fused_multihead_attention_paged_kv_params_v2 pagedKVParams;
Launch_params launch_params;
Fused_multihead_attention_params_v2 mParams;
Fused_multihead_attention_paged_kv_params_v2 mPagedKVParams;
Launch_params mLaunchParams;
int sm;
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
const FusedMultiHeadAttentionPagedKVXMMAKernelV2* pagedKVXmmaKernel;

View File

@ -301,9 +301,26 @@ struct Launch_params
bool granular_tiling = false;
// mask type: padding, causal, sliding_window_causal
ContextAttentionMaskType attention_mask_type = ContextAttentionMaskType::PADDING;
// use specialized kernels without alibi support.
bool useKernelWithoutAlibi = false;
// hardware properties used to determine how to launch blocks
int multi_processor_count = 0;
int device_l2_cache_size = 0;
void set_default_kernel_selection_params()
{
kernel_s = 0;
kernel_kv_s = 0;
force_unroll = false;
use_tma = false;
flash_attention = false;
warp_specialization = false;
granular_tiling = false;
attention_mask_type = (attention_mask_type == ContextAttentionMaskType::PADDING)
? ContextAttentionMaskType::PADDING
: ContextAttentionMaskType::CAUSAL;
useKernelWithoutAlibi = false;
}
};
} // namespace kernels

View File

@ -232,20 +232,21 @@ public:
}
inline uint64_t hashID(unsigned int s, unsigned int d, bool interleaved, bool unroll, bool force_fp32_acc,
bool flash_attention, int attention_mask_type, bool tiled) const
bool flash_attention, bool is_alibi_supported, int attention_mask_type, bool tiled) const
{
s = flash_attention ? 0 : s;
// D <= 2048
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 5) | (tiled ? 16ull : 0ull)
| (force_fp32_acc ? 8ull : 0ull) | (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull)
| (unroll ? 1ull : 0ull);
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 6) | (is_alibi_supported ? 32ull : 0ull)
| (tiled ? 16ull : 0ull) | (force_fp32_acc ? 8ull : 0ull) | (flash_attention ? 4ull : 0ull)
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}
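The kernel-lookup hash now reserves bit 5 for `is_alibi_supported` and shifts the mask type up to bit 6, so kernels compiled with and without alibi support are kept apart. A standalone sketch that reproduces the packing shown just above and checks that the two variants hash differently:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the updated packing: sequence length in the upper 32 bits, head size at
// bit 16, mask type from bit 6 upward, and one flag bit each for alibi support,
// tiling, fp32 accumulation, flash attention, interleaving, and unrolling.
std::uint64_t hashID(unsigned s, unsigned d, bool interleaved, bool unroll, bool forceFp32Acc,
    bool flashAttention, bool isAlibiSupported, int attentionMaskType, bool tiled)
{
    s = flashAttention ? 0 : s;
    return (std::uint64_t) s << 32 | d << 16 | (attentionMaskType << 6) | (isAlibiSupported ? 32ull : 0ull)
        | (tiled ? 16ull : 0ull) | (forceFp32Acc ? 8ull : 0ull) | (flashAttention ? 4ull : 0ull)
        | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}

int main()
{
    // Identical configurations except for alibi support now map to distinct kernels.
    auto const withAlibi = hashID(0, 128, false, true, true, true, true, /*causal*/ 1, false);
    auto const withoutAlibi = hashID(0, 128, false, true, true, true, false, /*causal*/ 1, false);
    assert(withAlibi != withoutAlibi);
    return 0;
}
```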
virtual uint64_t hashID(const KernelMeta& kernelMeta) const
{
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mAlibiSupported,
kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
}
virtual void run(
@ -288,10 +289,17 @@ public:
}
}
const auto findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
forceUnroll, launch_params.force_fp32_acc, launch_params.flash_attention,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "FMHA kernels are not found");
const auto findIter
= mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, forceUnroll,
launch_params.force_fp32_acc, launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
// Add debug info when kernels are not found.
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(),
"FMHA kernels are not found (kernel meta info: %d %d %d %d %d %d %d %d %d) !", launch_params.kernel_s,
params.d, launch_params.interleaved, forceUnroll, launch_params.force_fp32_acc,
launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling);
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;
@ -383,31 +391,39 @@ public:
}
inline uint64_t hashID(unsigned int s, unsigned int d, bool interleaved, bool unroll, bool force_fp32_acc,
bool flash_attention, bool warp_specialization, int attention_mask_type, bool tiled) const
bool flash_attention, bool warp_specialization, bool is_alibi_supported, int attention_mask_type,
bool tiled) const
{
s = flash_attention ? 0 : s;
// D <= 2048
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 6) | (warp_specialization ? 16ull : 0ull)
| (tiled ? 16ull : 0ull) | (force_fp32_acc ? 8ull : 0ull) | (flash_attention ? 4ull : 0ull)
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 7) | (is_alibi_supported ? 64ull : 0ull)
| (warp_specialization ? 32ull : 0ull) | (tiled ? 16ull : 0ull) | (force_fp32_acc ? 8ull : 0ull)
| (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}
virtual uint64_t hashID(const KernelMeta& kernelMeta) const
{
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization,
kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
kernelMeta.mAlibiSupported, kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
}
virtual void run(
Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const
{
const auto findIter = mFunctions.find(
hashID(launch_params.kernel_s, params.d, launch_params.interleaved, launch_params.force_unroll,
launch_params.force_fp32_acc, launch_params.flash_attention, launch_params.warp_specialization,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "FMHA kernels are not found");
const auto findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
launch_params.force_unroll, launch_params.force_fp32_acc, launch_params.flash_attention,
launch_params.warp_specialization, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
// Add debug info when kernels are not found.
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(),
"Paged KV FMHA kernels are not found (kernel meta info: %d %d %d %d %d %d %d %d %d %d) !",
launch_params.kernel_s, params.d, launch_params.interleaved, launch_params.force_unroll,
launch_params.force_fp32_acc, launch_params.flash_attention, launch_params.warp_specialization,
!launch_params.useKernelWithoutAlibi, static_cast<int>(launch_params.attention_mask_type),
launch_params.granular_tiling);
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;
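The hash key change above shifts the mask-type bits up by one to make room for the new `is_alibi_supported` flag, which is derived from `!launch_params.useKernelWithoutAlibi` at lookup time. A standalone mirror of the updated non-paged-KV `hashID` shows where each field lands; the numeric value used for the causal mask type (1) is an assumption for illustration only:

```cpp
// Mirror of the updated hashID above, to show the bit layout of the lookup key.
#include <cstdint>
#include <cstdio>

uint64_t hashId(unsigned s, unsigned d, bool interleaved, bool unroll, bool forceFp32Acc,
    bool flashAttention, bool isAlibiSupported, int attentionMaskType, bool tiled)
{
    s = flashAttention ? 0 : s; // flash-attention kernels ignore the sequence bucket
    return (uint64_t) s << 32 | d << 16 | (attentionMaskType << 6) | (isAlibiSupported ? 32ull : 0ull)
        | (tiled ? 16ull : 0ull) | (forceFp32Acc ? 8ull : 0ull) | (flashAttention ? 4ull : 0ull)
        | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}

int main()
{
    // s=128, d=64, causal mask (assumed value 1), ALiBi-capable kernel, unrolled.
    std::printf("0x%llx\n",
        (unsigned long long) hashId(128, 64, false, true, false, false, true, 1, false));
    // Prints 0x8000400061: s in bits 32+, d in bits 16+, mask type from bit 6,
    // alibi bit 5, tiled bit 4, fp32-acc bit 3, flash bit 2, interleave bit 1, unroll bit 0.
    return 0;
}
```

The paged-KV variant uses the same scheme with one extra `warp_specialization` bit, so its mask type starts at bit 7.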

View File

@ -142,9 +142,17 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
{
case CutlassGemmType::Simt: return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
case CutlassGemmType::WeightOnly:
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
if (sm >= 75)
{
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
}
else
{
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64};
}
case CutlassGemmType::Int8:
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
@ -156,8 +164,8 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
}
}
std::vector<CutlassGemmConfig> get_candidate_configs(
int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
const bool int8_configs_only, const int max_split_k)
{
std::vector<CutlassTileConfig> tiles
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);
@ -165,13 +173,20 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = int8_configs_only ? 3 : 2;
const int max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (const auto& tile_config : tiles)
{
for (int stages = min_stages; stages <= max_stages; ++stages)
{
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
candidate_configs.push_back(config);
if (sm >= 75)
{
for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor)
{
auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages};
candidate_configs.push_back(config);
}
}
}
}
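The added inner loop grows the tuning space considerably: for every tile shape and stage count, the original `NO_SPLIT_K` config is kept and, on SM75+, serial split-k factors 2 through `max_split_k` are appended. A back-of-the-envelope count for assumed inputs (three weight-only tiles on SM80, stages 2..4, and `max_split_k = 7`, matching the `SPLIT_K_LIMIT` shown below):

```cpp
// Rough count of candidates produced by the loop above under assumed inputs.
#include <cstdio>

int main()
{
    const int numTiles = 3, minStages = 2, maxStages = 4, maxSplitK = 7;
    int configs = 0;
    for (int tile = 0; tile < numTiles; ++tile)
        for (int stages = minStages; stages <= maxStages; ++stages)
            configs += 1 /* NO_SPLIT_K */ + (maxSplitK - 1) /* split-k 2..7 on SM75+ */;
    std::printf("%d candidate configs\n", configs); // 3 * 3 * 7 = 63
    return 0;
}
```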

View File

@ -26,8 +26,9 @@ namespace kernels
namespace cutlass_kernels
{
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(
int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false);
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false,
const int max_split_k = 1);
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
const std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig>& candidate_configs,

View File

@ -75,7 +75,7 @@ public:
protected:
static constexpr int SPLIT_K_LIMIT = 7;
static constexpr int MIN_M_TILE = 32;
static constexpr int MIN_N_TILE = 128;
static constexpr int MIN_N_TILE = 64;
};
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>

View File

@ -466,7 +466,8 @@ template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
std::vector<tkc::CutlassGemmConfig> CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::getConfigs() const
{
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
std::vector<tkc::CutlassGemmConfig> candidateConfigs = get_candidate_configs(sm_, is_weight_only, false);
std::vector<tkc::CutlassGemmConfig> candidateConfigs
= get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT);
return candidateConfigs;
}

View File

@ -368,7 +368,7 @@ std::vector<tkc::CutlassGemmConfig> CutlassInt8GemmRunner<T>::getConfigs() const
static constexpr bool isWeightOnly = false;
std::vector<tkc::CutlassGemmConfig> candidateConfigs
= get_candidate_configs(mSm, isWeightOnly, mSm <= 70, /* SIMT configs */
true); /* INT8 configs */
true, SPLIT_K_LIMIT); /* INT8 configs */
return candidateConfigs;
}

View File

@ -48,7 +48,7 @@ namespace tensorrt_llm
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape, int Stages>
void genericMoeGemmKernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy = nullptr)
{
@ -152,11 +152,11 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
struct dispatch_stages
{
static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
TLLM_THROW("Cutlass fpA_intB gemm. Not instantiates for arch %d with stages set to %d",
TLLM_THROW("Cutlass fpA_intB gemm. Not instantiated for arch %d with stages set to %d",
arch::kMinComputeCapability, Stages);
}
};
@ -166,12 +166,12 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
struct dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>
{
static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
genericMoeGemmKernelLauncher<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config,
weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config,
multi_processor_count, stream, occupancy);
}
};
@ -182,20 +182,20 @@ struct dispatch_stages<T, WeightType, cutlass::arch::Sm80, EpilogueTag, Threadbl
typename std::enable_if<(Stages > 2)>::type>
{
static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
genericMoeGemmKernelLauncher<T, WeightType, cutlass::arch::Sm80, EpilogueTag, ThreadblockShape, WarpShape,
Stages>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config,
multi_processor_count, stream, occupancy);
Stages>(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts,
gemm_config, multi_processor_count, stream, occupancy);
}
};
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape>
void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
@ -203,17 +203,17 @@ void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales,
{
case 2:
using DispatcherStages2 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>;
DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k,
DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case 3:
using DispatcherStages3 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>;
DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k,
DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case 4:
using DispatcherStages4 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>;
DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k,
DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break;
@ -233,18 +233,18 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
{
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
@ -268,18 +268,18 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
{
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
@ -301,8 +301,8 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
{
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 8>,
cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
break;
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
@ -356,6 +356,13 @@ void MoeGemmRunner<T, WeightType>::dispatchToArch<EpilogueTag>(const T* A, const
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
stream, occupancy);
}
else if (sm_ >= 90)
{
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
stream, occupancy);
}
else
{
TLLM_THROW("Arch unsupported for MoE GEMM");

View File

@ -17,10 +17,12 @@
#include "decoderMaskedMultiheadAttentionTemplate.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include <algorithm>
#include <cuda_runtime_api.h>
#ifdef ENABLE_FP8
#include <cuda_fp8.h>
@ -91,24 +93,35 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_AT
template <typename T, int Dh, bool DO_CROSS_ATTENTION>
inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
int blocks_per_sm, int block_size, int tlength, bool do_multi_block)
int blocks_per_sm, int block_size, int tlength)
{
if (!do_multi_block)
if (!params.multi_block_mode)
{
params.multi_block_mode = false;
return;
}
params.seq_len_tile
int balanced_seq_len_tile
= mmha::divUp(params.multi_processor_count * blocks_per_sm, params.batch_size * params.num_heads);
params.seq_len_tile = std::max(params.min_seq_len_tile, params.seq_len_tile);
const int threads_per_value = mmha::threads_per_value<T>(mmha::dh_max(Dh));
// Make sure that each block at least processes one loop of kv (unroll size is default at 8).
const int seq_len_per_kv_loop = mmha::divUp(block_size, threads_per_value) * 8;
const int max_seq_len_tile = std::min(mmha::divUp(tlength + 1, seq_len_per_kv_loop), params.max_seq_len_tile);
int max_seq_len_tile = params.max_seq_len_tile;
params.seq_len_tile = std::min(params.seq_len_tile, max_seq_len_tile);
const bool multi_block_debug_flag = getEnvMmhaMultiblockDebug();
// User defined number of blocks.
if (multi_block_debug_flag)
{
const int env_seq_len_tile = getEnvMmhaBlocksPerSequence();
balanced_seq_len_tile = env_seq_len_tile > 0 ? env_seq_len_tile : balanced_seq_len_tile;
}
else
{
max_seq_len_tile = std::min(mmha::divUp(tlength + 1, seq_len_per_kv_loop), max_seq_len_tile);
}
params.seq_len_tile = std::clamp(balanced_seq_len_tile, params.min_seq_len_tile, max_seq_len_tile);
TLLM_CHECK_WITH_INFO(
params.seq_len_tile <= block_size, "The number of blocks per sequence may not exceed the thread block size.");
@ -118,6 +131,14 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
params.multi_block_mode = (params.seq_len_tile > 1);
static bool debug_flag_printed_once = false;
if (multi_block_debug_flag && !debug_flag_printed_once)
{
TLLM_LOG_INFO("MMHA kernel info: threads per block(%d), launched_blocks_per_sequence(%d), sequence_length(%d).",
block_size, params.seq_len_tile, tlength + 1);
debug_flag_printed_once = true;
}
grid.z = params.seq_len_tile;
}
@ -230,7 +251,7 @@ void mmha_launch_kernel_ex(
}
// If blocks with larger block size already fill all SMs, then disable the multi blocks mode.
mmha::multi_block_grid_setup<T, Dh>(grid, params, available_blocks, dynamic_block_size, tlength, DO_MULTI_BLOCK);
mmha::multi_block_grid_setup<T, Dh>(grid, params, available_blocks, dynamic_block_size, tlength);
// Launch kernels based on the valid block size.
switch (dynamic_block_size)
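To make the new tiling logic concrete, here is a rough replay of the non-debug path (the `getEnvMmhaBlocksPerSequence` override is skipped) with made-up hardware numbers; `threads_per_value` and the clamp bounds are assumptions for illustration only:

```cpp
// Rough replay of the seq_len_tile selection above with illustrative numbers.
#include <algorithm>
#include <cstdio>

int divUp(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int multiProcessorCount = 108, blocksPerSm = 2;   // e.g. an A100-class GPU
    const int batchSize = 1, numHeads = 16, blockSize = 256, tlength = 2047;
    const int threadsPerValue = 16;                          // assumed for this head size
    const int minSeqLenTile = 1, maxSeqLenTileParam = 32;    // assumed params

    int balanced = divUp(multiProcessorCount * blocksPerSm, batchSize * numHeads);          // 14
    int seqLenPerKvLoop = divUp(blockSize, threadsPerValue) * 8;                            // 128
    int maxSeqLenTile = std::min(divUp(tlength + 1, seqLenPerKvLoop), maxSeqLenTileParam);  // 16
    int seqLenTile = std::clamp(balanced, minSeqLenTile, maxSeqLenTile);                    // 14
    std::printf("grid.z = %d, multi_block_mode = %d\n", seqLenTile, seqLenTile > 1);
    return 0;
}
```

With a larger batch-times-heads product the balanced value drops to 1 and multi-block mode is disabled, which is exactly the "blocks already fill all SMs" case mentioned above.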

View File

@ -2549,7 +2549,7 @@ __global__ void masked_multihead_attention_kernel(
float final_max = -FLT_MAX;
float thread_partial_max = -FLT_MAX;
thread_partial_max = params.partial_max[bhi_seq_len_tile + min(tidx, gridDim.x - 1)];
thread_partial_max = params.partial_max[bhi_seq_len_tile + min(tidx, gridDim.z - 1)];
// Make sure we can start writing to shared memory.
__syncthreads();

View File

@ -33,14 +33,15 @@ namespace tensorrt_llm
namespace kernels
{
template <int M1, int N1, int K1, int M2, int N2, int K2>
void run_cutlass_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
template <int M1, int N1, int K1, int M2, int N2, int K2, typename cutlassType>
void groupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize, void* cublasWorkSpace,
int64_t cublasWorkspaceSize, cudaStream_t stream)
int64_t cublasWorkspaceSize, nvinfer1::DataType dataType, cudaStream_t stream)
{
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
using ElementA = cutlassType;
using ElementB = cutlassType;
using ElementOutput = cutlassType;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
@ -71,9 +72,10 @@ void run_cutlass_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
float beta = 0.0f;
typename Gemm::EpilogueOutputOp::Params epilogue_op(alpha, beta);
auto gemm_coord_size = tensorrt_llm::common::divUp(problem_count * sizeof(cutlass::gemm::GemmCoord), 16) * 16;
auto ptr_size = tensorrt_llm::common::divUp(problem_count * sizeof(half*), 16) * 16;
auto ldd_size = tensorrt_llm::common::divUp(problem_count * sizeof(int64_t), 16) * 16;
auto gemm_coord_size
= tensorrt_llm::common::divUp((size_t) problem_count * sizeof(cutlass::gemm::GemmCoord), (size_t) 16) * 16;
auto ptr_size = tensorrt_llm::common::divUp((size_t) problem_count * sizeof(half*), (size_t) 16) * 16;
auto ldd_size = tensorrt_llm::common::divUp((size_t) problem_count * sizeof(int64_t), (size_t) 16) * 16;
char* host_workspace = (char*) std::malloc(workSpaceSize);
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_workspace);
@ -135,24 +137,46 @@ void run_cutlass_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
std::free(host_workspace);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void run_cutlass_1(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
template <int M1, int N1, int K1, int M2, int N2, int K2>
void groupedGemmType_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream)
void* cublasWorkSpace, int64_t cublasWorkspaceSize, nvinfer1::DataType dataType, cudaStream_t stream)
{
// For lora in, which has smaller N
run_cutlass_<128, 32, 32, 32, 32, 32>(
problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, stream);
if (dataType == nvinfer1::DataType::kHALF)
{
groupedGemm_<M1, N1, K1, M2, N2, K2, cutlass::half_t>(problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace,
workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, dataType, stream);
}
else if (dataType == nvinfer1::DataType::kFLOAT)
{
TLLM_CHECK_WITH_INFO(false, "float input/output is not supported");
}
#ifdef ENABLE_BF16
else if (dataType == nvinfer1::DataType::kBF16)
{
groupedGemm_<M1, N1, K1, M2, N2, K2, cutlass::bfloat16_t>(problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace,
workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, dataType, stream);
}
#endif
}
void run_cutlass_2(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream)
void gropuedGemm(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize, void* cublasWorkSpace,
int64_t cublasWorkspaceSize, bool isLoraIn, nvinfer1::DataType dataType, cudaStream_t stream)
{
// For lora out, which has larger N
run_cutlass_<128, 128, 32, 64, 64, 32>(
problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, stream);
if (isLoraIn)
{
groupedGemmType_<16, 32, 64, 16, 32, 64>(problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize,
cublasWorkSpace, cublasWorkspaceSize, dataType, stream);
}
else
{
groupedGemmType_<32, 128, 32, 32, 32, 32>(problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize,
cublasWorkSpace, cublasWorkspaceSize, dataType, stream);
}
}
} // namespace kernels

View File

@ -16,19 +16,16 @@
#pragma once
#include "cutlass/gemm_coord.h"
#include <NvInferRuntime.h>
namespace tensorrt_llm
{
namespace kernels
{
void run_cutlass_1(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream);
void run_cutlass_2(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream);
void gropuedGemm(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize, void* cublasWorkSpace,
int64_t cublasWorkspaceSize, bool isLoraIn, nvinfer1::DataType dataType, cudaStream_t stream);
} // namespace kernels
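The two fixed-shape entry points collapse into one function (spelled `gropuedGemm` in the header) that dispatches on `isLoraIn` for the tile shape and on `dataType` for the CUTLASS element type. A minimal call sketch; the buffer pointers and the LoRA problem shape are placeholders, and the include path of this header is not visible in the diff, so it is omitted:

```cpp
// Sketch only: calling the consolidated grouped-GEMM entry point declared above.
#include <cstdint>
#include <cuda_runtime_api.h>
#include <vector>

#include <NvInferRuntime.h>
#include "cutlass/gemm_coord.h"

void loraGroupedGemmExample(std::vector<void*> const& a, std::vector<void*> const& b,
    std::vector<void*> const& c, std::vector<void*> const& d, void* workspace, int64_t workspaceSize,
    void* cublasWs, int64_t cublasWsSize, cudaStream_t stream)
{
    // One GEMM per LoRA module; m = tokens, n = adapter rank, k = hidden size (placeholders).
    std::vector<cutlass::gemm::GemmCoord> problems{{/*m=*/64, /*n=*/16, /*k=*/4096}};

    // isLoraIn = true selects the small-N tile configuration (the 16x32x64 shapes above);
    // dataType picks cutlass::half_t or cutlass::bfloat16_t inside groupedGemmType_.
    tensorrt_llm::kernels::gropuedGemm(problems, a, b, c, d, workspace, workspaceSize,
        cublasWs, cublasWsSize, /*isLoraIn=*/true, nvinfer1::DataType::kHALF, stream);
}
```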

View File

@ -15,6 +15,7 @@
* limitations under the License.
*/
#include "tensorrt_llm/common/workspace.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
@ -58,8 +59,9 @@ static constexpr int WARP_SIZE = 32;
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moeSoftmax(const T* input, const bool* finished, T* output, const int num_cols)
template <int TPB>
__launch_bounds__(TPB) __global__
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@ -81,7 +83,7 @@ __launch_bounds__(TPB) __global__ void moeSoftmax(const T* input, const bool* fi
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
threadData = max(input[idx], threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
@ -111,16 +113,16 @@ __launch_bounds__(TPB) __global__ void moeSoftmax(const T* input, const bool* fi
{
const int idx = thread_row_offset + ii;
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = T(val);
output[idx] = val;
}
}
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const T* inputs_after_softmax, const bool* finished, T* output,
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{
using cub_kvp = cub::KeyValuePair<int, T>;
using cub_kvp = cub::KeyValuePair<int, float>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@ -135,7 +137,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(const T* inputs_after_softmax, co
for (int k_idx = 0; k_idx < k; ++k_idx)
{
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
@ -189,9 +191,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const T* inputs_after_softmax, co
2) This implementation assumes k is small, but will work for any k.
*/
template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const T* input, const bool* finished, T* output, const int num_rows, int* indices,
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
@ -201,7 +203,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
@ -243,20 +245,20 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
using AccessType = cutlass::AlignedArray<T, ELTS_PER_LDG>;
using AccessType = cutlass::AlignedArray<float, ELTS_PER_LDG>;
// Finally, we pull in the data from global mem
cutlass::Array<T, VPT> row_chunk_input;
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_input);
cutlass::Array<float, VPT> row_chunk;
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
@ -264,14 +266,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
using ComputeType = float;
using Converter = cutlass::NumericArrayConverter<ComputeType, T, VPT>;
Converter compute_type_converter;
cutlass::Array<ComputeType, VPT> row_chunk = compute_type_converter(row_chunk_input);
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
ComputeType thread_max = row_chunk[0];
float thread_max = row_chunk[0];
#pragma unroll
for (int ii = 1; ii < VPT; ++ii)
{
@ -370,7 +367,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
// single) thread per row of the input/output matrices.
const int idx = k * thread_row + k_idx;
output[idx] = T(max_val);
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
}
@ -386,7 +383,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
{
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f);
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
}
}
}
@ -395,10 +392,10 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <typename T, int EXPERTS, int BYTES_PER_LDG>
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
@ -407,28 +404,27 @@ struct TopkConstants
};
} // namespace detail
template <typename T, int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const T* input, const bool* finished, T* output, int* indices, int* source_row,
const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}
template <typename T>
void topkGatingSoftmaxKernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output,
int* indices, int* source_row, const int num_rows, const int num_experts, const int k, const int start_expert,
const int end_expert, cudaStream_t stream)
void topkGatingSoftmaxKernelLauncher(const float* input, const bool* finished, float* output,
float* softmax_temp_output, int* indices, int* source_row, const int num_rows, const int num_experts, const int k,
const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr int WARPS_PER_TB = 4;
@ -436,55 +432,55 @@ void topkGatingSoftmaxKernelLauncher(const T* input, const bool* finished, T* ou
{
case 1:
{
topkGatingSoftmaxLauncherHelper<T, 1, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<1, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 2:
{
topkGatingSoftmaxLauncherHelper<T, 2, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<2, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 4:
{
topkGatingSoftmaxLauncherHelper<T, 4, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<4, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 8:
{
topkGatingSoftmaxLauncherHelper<T, 8, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<8, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 16:
{
topkGatingSoftmaxLauncherHelper<T, 16, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<16, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 32:
{
topkGatingSoftmaxLauncherHelper<T, 32, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<32, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 64:
{
topkGatingSoftmaxLauncherHelper<T, 64, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<64, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 128:
{
topkGatingSoftmaxLauncherHelper<T, 128, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<128, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 256:
{
topkGatingSoftmaxLauncherHelper<T, 256, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<256, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
@ -492,8 +488,8 @@ void topkGatingSoftmaxKernelLauncher(const T* input, const bool* finished, T* ou
{
static constexpr int TPB = 256;
TLLM_CHECK(softmax_temp_output != nullptr);
moeSoftmax<T, TPB><<<num_rows, TPB, 0, stream>>>(input, finished, softmax_temp_output, num_experts);
moeTopK<T, TPB><<<num_rows, TPB, 0, stream>>>(
moeSoftmax<TPB><<<num_rows, TPB, 0, stream>>>(input, finished, softmax_temp_output, num_experts);
moeTopK<TPB><<<num_rows, TPB, 0, stream>>>(
softmax_temp_output, finished, output, indices, source_row, num_experts, k, start_expert, end_expert);
}
}
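With the routing path now hard-wired to `float`, the launcher still picks between two paths: power-of-two expert counts up to 256 hit a fused `topkGatingSoftmax` specialization, everything else falls back to the two-kernel `moeSoftmax` + `moeTopK` sequence. A worked instantiation of the compile-time constants for 64 experts; `ROWS_PER_WARP` is not visible in this hunk, so `WARP_SIZE / THREADS_PER_ROW` is assumed here for illustration:

```cpp
// Worked instantiation of the float-only TopkConstants above for EXPERTS = 64.
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int WARP_SIZE = 32, WARPS_PER_TB = 4, EXPERTS = 64;
    constexpr int BYTES_PER_LDG = std::min(16, int(sizeof(float)) * EXPERTS);            // 16
    constexpr int ELTS_PER_LDG = BYTES_PER_LDG / int(sizeof(float));                     // 4 floats per load
    constexpr int VECS_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));   // 1
    constexpr int VPT = VECS_PER_THREAD * ELTS_PER_LDG;                                  // 4 values per thread
    constexpr int THREADS_PER_ROW = EXPERTS / VPT;                                       // 16
    constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;                           // 2 (assumed)

    const int num_rows = 1024;
    const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;                // 512
    const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;                // 128
    std::printf("VPT=%d threads/row=%d blocks=%d\n", VPT, THREADS_PER_ROW, num_blocks);
    return 0;
}
```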
@ -649,7 +645,7 @@ enum class ScaleMode : int
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection.
template <typename T, int RESIDUAL_NUM, bool HAS_BIAS, ScaleMode SCALE_MODE, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1,
const T* skip_2, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row,
const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row, const int cols, const int k, const int64_t* num_valid_ptr)
{
const int original_row = blockIdx.x;
@ -672,14 +668,14 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red
for (int tid = threadIdx.x; tid < cols; tid += blockDim.x)
{
T thread_output{0.f};
T row_rescale{0.f};
float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx)
{
const int expanded_original_row = original_row + k_idx * num_rows;
const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row];
const int64_t k_offset = original_row * k + k_idx;
const T row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? T(1.f) : scales[k_offset];
const float row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
if constexpr (SCALE_MODE == ScaleMode::RENORM_SCALE)
{
row_rescale = row_rescale + row_scale;
@ -698,13 +694,14 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red
const T* bias_ptr = bias + expert_idx * cols;
const T bias_value = HAS_BIAS ? bias_ptr[tid] : T(0.f);
thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_value);
thread_output = static_cast<float>(thread_output)
+ row_scale * static_cast<float>(expanded_permuted_rows_row_ptr[tid] + bias_value);
}
if (SCALE_MODE == ScaleMode::RENORM_SCALE && (!CHECK_SKIPPED || thread_output))
{
assert(row_rescale != T(0.f));
thread_output = thread_output / row_rescale;
assert(row_rescale != 0.f);
thread_output = static_cast<float>(thread_output) / row_rescale;
}
if (RESIDUAL_NUM == 1)
@ -721,7 +718,7 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red
template <typename T, int RESIDUAL_NUM>
void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows, T* reduced_unpermuted_output,
const T* skip_1, const T* skip_2, const T* bias, const T* scales,
const T* skip_1, const T* skip_2, const T* bias, const float* scales,
const int* expanded_source_row_to_expanded_dest_row, const int* expert_for_source_row, const int num_rows,
const int cols, const int k, const int64_t* num_valid_ptr, MOEParallelismConfig parallelism_config,
MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
@ -766,7 +763,7 @@ void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows,
template <typename T>
void finalizeMoeRoutingKernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1,
const T* skip_2, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row,
const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row, const int num_rows, const int cols, const int k, const int64_t* num_valid_ptr,
MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
{
@ -834,6 +831,47 @@ void doGatedActivation(T* output, const T* gemm_result, const int64_t* num_valid
fn<<<blocks, threads, 0, stream>>>(output, gemm_result, num_valid_tokens_ptr, inter_size);
}
template <typename T, typename WeightType, typename Enable>
std::vector<size_t> CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceBufferSizes(const int num_rows,
const int hidden_size, const int inter_size, const int num_experts, const int num_experts_per_node, const int k,
ActivationType activation_type) const
{
const size_t num_moe_inputs = k * num_rows;
const size_t buf_size = num_moe_inputs * hidden_size;
const size_t interbuf_elems = num_moe_inputs * inter_size;
const size_t glu_inter_elems = isGatedActivation(activation_type) ? (interbuf_elems * 2) : 0;
int num_softmax_outs = 0;
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256)
{
num_softmax_outs = num_rows * num_experts;
}
size_t source_rows_size = num_moe_inputs * sizeof(int);
size_t permuted_rows_size = num_moe_inputs * sizeof(int);
size_t permuted_experts_size = num_moe_inputs * sizeof(int);
size_t permuted_data_size = buf_size * sizeof(T);
size_t total_rows_before_expert_size = num_experts_per_node * sizeof(int64_t);
size_t softmax_out_size = num_softmax_outs * sizeof(float);
size_t glu_inter_size = glu_inter_elems * sizeof(T);
size_t fc1_result_size = interbuf_elems * sizeof(T);
size_t sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts);
std::vector<size_t> workspace{
source_rows_size,
permuted_rows_size,
permuted_experts_size,
permuted_data_size,
total_rows_before_expert_size,
softmax_out_size,
glu_inter_size,
// These pointers reuse the same memory
std::max(fc1_result_size, sorter_size),
};
return workspace;
}
template <typename T, typename WeightType, typename Enable>
size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(const int num_rows, const int hidden_size,
const int inter_size, const int num_experts, const int k, ActivationType activation_type,
@ -841,38 +879,9 @@ size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(const int num
{
const int ep_size = parallelism_config.ep_size;
TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of tp size");
const int buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size);
const int glu_inter_size
= isGatedActivation(activation_type) ? pad_to_multiple_of_16(k * num_rows * inter_size * 2) : 0;
const int interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size) + glu_inter_size;
const int padded_experts = pad_to_multiple_of_16(num_experts / ep_size);
const int num_moe_inputs = pad_to_multiple_of_16(k * num_rows);
int num_softmax_outs = 0;
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256)
{
num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts);
}
// softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them
// in Encoder or Decoder before invoking FfnLayer forward.
size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
total_ws_bytes += buf_size * sizeof(T); // permuted_data
total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
total_ws_bytes += num_softmax_outs * sizeof(T);
const int bytes_for_fc1_result = interbuf_size * sizeof(T);
const int sorter_ws_size_bytes = pad_to_multiple_of_16(CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts));
int bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result)
{
int remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result);
bytes_for_intermediate_and_sorting += remaining_bytes;
}
total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace
return total_ws_bytes;
auto workspace = getWorkspaceBufferSizes(
num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type);
return tensorrt_llm::common::calculateTotalWorkspaceSize(workspace.data(), workspace.size());
}
template <typename T, typename WeightType, typename Enable>
@ -880,45 +889,43 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::configureWsPtrs(char* ws_ptr, co
const int inter_size, const int num_experts, const int num_experts_per_node, const int k,
ActivationType activation_type)
{
const int buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size);
const int glu_inter_size
= isGatedActivation(activation_type) ? pad_to_multiple_of_16(k * num_rows * inter_size * 2) : 0;
const int interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size) + glu_inter_size;
const int sorter_ws_size_bytes = pad_to_multiple_of_16(CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts));
const size_t interbuf_sorter_size_bytes = std::max(interbuf_size * sizeof(T), (size_t) sorter_ws_size_bytes);
const int padded_experts = pad_to_multiple_of_16(num_experts_per_node);
const int num_moe_inputs = pad_to_multiple_of_16(k * num_rows);
auto workspace = getWorkspaceBufferSizes(
num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, activation_type);
source_rows_ = (int*) ws_ptr;
permuted_rows_ = source_rows_ + num_moe_inputs;
permuted_experts_ = permuted_rows_ + num_moe_inputs;
permuted_data_ = (T*) (permuted_experts_ + num_moe_inputs);
std::vector<int8_t*> ws_sliced{(int8_t*) ws_ptr};
for (auto size : workspace)
{
ws_sliced.push_back(nextWorkspacePtr(ws_sliced.back(), size));
}
total_rows_before_expert_ = (int64_t*) (permuted_data_ + buf_size);
source_rows_ = (int*) ws_sliced[0];
permuted_rows_ = (int*) ws_sliced[1];
permuted_experts_ = (int*) ws_sliced[2];
permuted_data_ = (T*) ws_sliced[3];
// These pointers are aliased. Since the sort ws can be overwritten after it is finished
glu_inter_result_ = (T*) (total_rows_before_expert_ + padded_experts);
fc1_result_ = glu_inter_result_ + glu_inter_size;
sorter_ws_ = (char*) glu_inter_result_;
total_rows_before_expert_ = (int64_t*) ws_sliced[4];
softmax_out_ = nullptr;
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256)
{
softmax_out_ = (T*) (sorter_ws_ + interbuf_sorter_size_bytes);
}
else
{
softmax_out_ = nullptr;
softmax_out_ = (float*) ws_sliced[5];
}
glu_inter_result_ = (T*) ws_sliced[6];
// These pointers are aliased. Since the sort ws can be overwritten after it is finished
sorter_ws_ = (char*) ws_sliced[7];
fc1_result_ = (T*) ws_sliced[7];
}
template <typename T, typename WeightType, typename Enable>
void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activations_void,
const void* gating_output_void, const void* fc1_expert_weights_void, const void* fc1_scales_void,
const void* fc1_expert_biases_void, ActivationType fc1_activation_type, const void* fc2_expert_weights_void,
const void* fc2_scales_void, const void* fc2_expert_biases_void, const int num_rows, const int hidden_size,
const int inter_size, const int num_experts, const int k, char* workspace_ptr, void* final_output_void,
void* fc2_result_void, const bool* finished, const int active_rows, void* expert_scales_void,
void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activations_void, const float* gating_output,
const void* fc1_expert_weights_void, const void* fc1_scales_void, const void* fc1_expert_biases_void,
ActivationType fc1_activation_type, const void* fc2_expert_weights_void, const void* fc2_scales_void,
const void* fc2_expert_biases_void, const int num_rows, const int hidden_size, const int inter_size,
const int num_experts, const int k, char* workspace_ptr, void* final_output_void, void* fc2_result_void,
const bool* finished, const int active_rows, void* expert_scales_void,
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, MOEParallelismConfig parallelism_config,
MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
{
@ -926,7 +933,6 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activat
= std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value;
auto* input_activations = static_cast<const T*>(input_activations_void);
auto* gating_output = static_cast<const T*>(gating_output_void);
auto* fc1_expert_weights = static_cast<const WeightType*>(fc1_expert_weights_void);
auto* fc1_scales = static_cast<const T*>(fc1_scales_void);
auto* fc1_expert_biases = static_cast<const T*>(fc1_expert_biases_void);
@ -935,7 +941,7 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activat
auto* fc2_expert_biases = static_cast<const T*>(fc2_expert_biases_void);
auto* final_output = static_cast<T*>(final_output_void);
auto* fc2_result = static_cast<T*>(fc2_result_void);
auto* expert_scales = static_cast<T*>(expert_scales_void);
auto* expert_scales = static_cast<float*>(expert_scales_void);
TLLM_CHECK(input_activations);
TLLM_CHECK(gating_output);
@ -965,7 +971,7 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activat
configureWsPtrs(
workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, fc1_activation_type);
topkGatingSoftmaxKernelLauncher<T>(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
topkGatingSoftmaxKernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
source_rows_, num_rows, num_experts, k, start_expert, end_expert, stream);
sync_check_cuda_error();
@ -1008,8 +1014,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activat
sync_check_cuda_error();
doGatedActivation<T>(
fc1_result_, glu_inter_result_, num_valid_tokens_ptr, inter_size, num_rows, fc1_activation_type, stream);
doGatedActivation<T>(fc1_result_, glu_inter_result_, num_valid_tokens_ptr, inter_size, num_rows * k,
fc1_activation_type, stream);
}
sync_check_cuda_error();
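Stepping back from the individual hunks, the workspace refactor in this file follows a size-list/slice-list pattern: `getWorkspaceBufferSizes` returns one size per buffer, `getWorkspaceSize` sums them via `calculateTotalWorkspaceSize`, and `configureWsPtrs` walks the same list with `nextWorkspacePtr` so every pointer is derived from the same offsets. A minimal standalone sketch of that pattern; the 16-byte alignment is an assumption for illustration, the real helpers live in `tensorrt_llm/common/workspace.h`:

```cpp
// Sketch of the size-list / slice-list workspace pattern used above (assumed 16-byte alignment).
#include <cstdint>
#include <cstdio>
#include <vector>

size_t alignUp(size_t x, size_t a) { return (x + a - 1) / a * a; }

size_t totalWorkspaceSize(const std::vector<size_t>& sizes)
{
    size_t total = 0;
    for (size_t s : sizes)
        total += alignUp(s, 16);
    return total;
}

std::vector<int8_t*> sliceWorkspace(int8_t* base, const std::vector<size_t>& sizes)
{
    // slices[i] is the start of buffer i; slices[i + 1] is one past its aligned end.
    std::vector<int8_t*> slices{base};
    for (size_t s : sizes)
        slices.push_back(slices.back() + alignUp(s, 16));
    return slices;
}

int main()
{
    // e.g. source_rows, permuted_rows, permuted_experts, permuted_data, ... (placeholder sizes)
    std::vector<size_t> sizes{4096, 4096, 4096, 65536};
    std::printf("total = %zu bytes\n", totalWorkspaceSize(sizes));
    return 0;
}
```

Keeping sizing and slicing driven by the same vector is what lets `softmax_out_`, `glu_inter_result_`, and the aliased sorter/fc1 buffer move around without the two code paths drifting apart.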

View File

@ -137,7 +137,7 @@ public:
virtual void setTactic(std::optional<cutlass_extensions::CutlassGemmConfig> gemm_config) = 0;
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() = 0;
virtual void runMoe(const void* input_activations, const void* gating_output, const void* fc1_expert_weights,
virtual void runMoe(const void* input_activations, const float* gating_output, const void* fc1_expert_weights,
const void* fc1_scales, const void* fc1_expert_biases, ActivationType fc1_activation_type,
const void* fc2_expert_weights, const void* fc2_scales, const void* fc2_expert_biases, const int num_rows,
const int hidden_size, const int inter_size, const int num_experts, const int k, char* workspace_ptr,
@ -173,7 +173,7 @@ public:
return moe_gemm_runner_.getConfigs();
}
void runMoe(const void* input_activations, const void* gating_output, const void* fc1_expert_weights,
void runMoe(const void* input_activations, const float* gating_output, const void* fc1_expert_weights,
const void* fc1_scales, const void* fc1_expert_biases, ActivationType fc1_activation_type,
const void* fc2_expert_weights, const void* fc2_scales, const void* fc2_expert_biases, const int num_rows,
const int hidden_size, const int inter_size, const int num_experts, const int k, char* workspace_ptr,
@ -185,6 +185,8 @@ public:
private:
void computeTotalRowsBeforeExpert(const int* sorted_indices, const int total_indices, const int num_experts,
int64_t* total_rows_before_expert, cudaStream_t stream);
std::vector<size_t> getWorkspaceBufferSizes(const int num_rows, const int hidden_size, const int inter_size,
const int num_experts, const int num_experts_per_node, const int k, ActivationType activation_type) const;
void configureWsPtrs(char* ws_ptr, const int num_rows, const int hidden_size, const int inter_size,
const int num_experts, const int num_experts_per_node, const int k, ActivationType activation_type);
@ -198,7 +200,7 @@ private:
int* permuted_experts_;
char* sorter_ws_;
T* permuted_data_;
T* softmax_out_;
float* softmax_out_;
int64_t* total_rows_before_expert_;
@ -224,7 +226,7 @@ public:
return;
}
void runMoe(const void* input_activations, const void* gating_output, const void* fc1_expert_weights,
void runMoe(const void* input_activations, const float* gating_output, const void* fc1_expert_weights,
const void* fc1_scales, const void* fc1_expert_biases, ActivationType fc1_activation_type,
const void* fc2_expert_weights, const void* fc2_scales, const void* fc2_expert_biases, const int num_rows,
const int hidden_size, const int inter_size, const int num_experts, const int k, char* workspace_ptr,

View File

@ -209,8 +209,11 @@ __global__ void batchApplyRepetitionPenalty(T* logits, const float* penalties, c
{
// outputIds shape: (batchSize, input_len + output_len)
int penaltyIndex = outputIds[batchIdx][blockIdx.y * maxSeqLen + index];
assert(penaltyIndex < vocabSize);
penaltyIndices[index] = penaltyIndex;
if (penaltyIndex >= vocabSize)
{
continue;
}
float logit = (float) logits[penaltyIndex];
if (penaltyType == RepetitionPenaltyType::Additive)
{
@ -239,6 +242,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits, const float* penalties, c
// Phase 2. Replace a logit value by the penalized one.
for (int index = threadIdx.x; index < currentStep; index += blockDim.x)
{
if (penaltyIndices[index] >= vocabSize)
{
continue;
}
logits[penaltyIndices[index]] = penaltyLogits[index];
}
}
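The change above replaces an assertion on `penaltyIndex` with a guard that skips token ids outside the vocabulary, in both the gather and the write-back phases. Here is a standalone CPU sketch of the same penalty logic with that guard; the loop structure and names are simplified, and only the penalty formulas and the out-of-vocab skip mirror the kernel.

```cpp
#include <vector>

enum class RepetitionPenaltyType
{
    Additive,
    Multiplicative
};

// Apply a repetition penalty to the logits of previously generated tokens,
// skipping ids that fall outside the vocabulary instead of asserting on them.
void applyRepetitionPenalty(std::vector<float>& logits, const std::vector<int>& previousTokens,
    int vocabSize, float penalty, RepetitionPenaltyType type)
{
    for (int tokenId : previousTokens)
    {
        if (tokenId < 0 || tokenId >= vocabSize)
        {
            continue; // out-of-range ids are ignored rather than asserted on
        }
        float logit = logits[tokenId];
        if (type == RepetitionPenaltyType::Additive)
        {
            logits[tokenId] = logit - penalty;
        }
        else
        {
            // Multiplicative penalty pushes repeated tokens' logits toward less likely values.
            logits[tokenId] = logit < 0.0f ? logit * penalty : logit / penalty;
        }
    }
}

int main()
{
    std::vector<float> logits{1.0f, -2.0f, 0.5f};
    // Token id 7 is outside the vocabulary of size 3 and is silently skipped.
    applyRepetitionPenalty(logits, {0, 7, 1}, 3, 1.2f, RepetitionPenaltyType::Multiplicative);
    return 0;
}
```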

View File

@ -37,10 +37,14 @@ __global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_i
int time_step = threadIdx.x + blockIdx.x * blockDim.x;
int bb_id = threadIdx.y + blockIdx.y * blockDim.y; // should be just blockIdx.y?
const int current_step{sequence_lengths[bb_id] - 1}; // sequence_lengths has already been updated, so subtract 1
const int input_length{input_lengths == nullptr ? 0 : input_lengths[bb_id]};
const int batch_id = bb_id / beam_width;
const int beam_id = bb_id % beam_width;
if (bb_id >= beam_width * local_batch_size || time_step < (max_seq_len - max_attention_window)
|| finished[bb_id].isFinished())
// Exit when the batch-beam index or the time step is out of bounds.
// The KV cache of the context part is assumed to be shared and fixed,
// so its indices do not need to be updated.
if (bb_id >= beam_width * local_batch_size || time_step >= max_seq_len || time_step < input_length
|| time_step < (max_seq_len - max_attention_window) || finished[bb_id].isFinished())
{
return;
}
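A host-side restatement of the extended guard above, written as a plain function so the skip conditions are easy to read; the parameter names follow the kernel, but the function itself is only an illustration.

```cpp
#include <cstdio>

// Returns true when the indirection-cache update should be skipped for this
// (batch-beam, time step) pair: out-of-range slots, out-of-range or context
// time steps, steps evicted from the attention window, or finished sequences.
bool shouldSkipIndirCacheUpdate(int bbId, int timeStep, int localBatchSize, int beamWidth,
    int inputLength, int maxSeqLen, int maxAttentionWindow, bool finished)
{
    return bbId >= beamWidth * localBatchSize            // batch-beam slot out of range
        || timeStep >= maxSeqLen                         // step beyond the allocated cache
        || timeStep < inputLength                        // context part: indices never change
        || timeStep < maxSeqLen - maxAttentionWindow     // evicted from the sliding window
        || finished;                                     // sequence already finished
}

int main()
{
    // A generated step inside the window for an active sequence is not skipped.
    bool skip = shouldSkipIndirCacheUpdate(/*bbId=*/0, /*timeStep=*/10, /*localBatchSize=*/1,
        /*beamWidth=*/2, /*inputLength=*/8, /*maxSeqLen=*/32, /*maxAttentionWindow=*/32,
        /*finished=*/false);
    std::printf("skip=%d\n", skip ? 1 : 0);
    return 0;
}
```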
@ -90,14 +94,14 @@ BaseBeamSearchLayer<T>::BaseBeamSearchLayer(BaseBeamSearchLayer<T> const& beam_s
template <typename T>
BaseBeamSearchLayer<T>::~BaseBeamSearchLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}
template <typename T>
void BaseBeamSearchLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&temperature_buf_));
@ -105,26 +109,26 @@ void BaseBeamSearchLayer<T>::freeBuffer()
allocator_->free((void**) (&repetition_penalty_buf_));
is_allocate_buffer_ = false;
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void BaseBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
temperature_buf_ = allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false);
min_lengths_buf_ = allocator_->reMalloc(min_lengths_buf_, sizeof(int) * batch_size, false);
repetition_penalty_buf_ = allocator_->reMalloc(repetition_penalty_buf_, sizeof(float) * batch_size, false);
is_allocate_buffer_ = true;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void BaseBeamSearchLayer<T>::setupBase(size_t batch_size, SetupParams const& setupParams)
{
allocateBuffer(batch_size);
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// Setup penalties.
FillBuffers const fillBuffers{batch_size, stream_};
@ -149,13 +153,13 @@ void BaseBeamSearchLayer<T>::setupBase(size_t batch_size, SetupParams const& set
fillBuffers(setupParams.presence_penalty, 1.0f, mRepetitionPenalty, repetition_penalty_buf_);
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardParams const& params)
{
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
Tensor& output_ids_ptr = outputs.output_ids_ptr;
const auto batch_size = static_cast<std::int32_t>(output_ids_ptr.shape[0]);

View File

@ -33,7 +33,7 @@ namespace layers
template <typename T>
void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
curandstate_buf_ = allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false);
random_seeds_buf_ = allocator_->reMalloc(random_seeds_buf_, sizeof(unsigned long long) * batch_size, false);
temperature_buf_ = allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false);
@ -52,7 +52,7 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size)
template <typename T>
void BaseSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&curandstate_buf_));
@ -93,7 +93,7 @@ BaseSamplingLayer<T>::~BaseSamplingLayer()
template <typename T>
void BaseSamplingLayer<T>::setupBase(const size_t batch_size, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
allocateBuffer(batch_size);
// If runtime argument has single random seed, using this random seed to
@ -170,7 +170,7 @@ void BaseSamplingLayer<T>::setupBase(const size_t batch_size, SetupParams const&
template <typename T>
void BaseSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams const& params)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batch_size = outputs.output_ids_ptr.shape[0];
auto const local_batch_size = params.logits.shape[0];
@ -234,7 +234,7 @@ void BaseSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams
freeBuffer();
}
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template class BaseSamplingLayer<float>;

View File

@ -37,7 +37,7 @@ namespace layers
template <typename T>
void DynamicDecodeLayer<T>::initialize()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mOnlineBeamsearchDecode = std::make_unique<OnlineBeamSearchLayer<T>>(
vocab_size_, vocab_size_padded_, stream_, allocator_, is_free_buffer_after_forward_);
@ -58,14 +58,14 @@ DynamicDecodeLayer<T>::DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_p
, vocab_size_padded_(vocab_size_padded)
, cuda_device_prop_(cuda_device_prop)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
initialize();
}
template <typename T>
DynamicDecodeLayer<T>::~DynamicDecodeLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}
@ -76,7 +76,7 @@ DynamicDecodeLayer<T>::DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_deco
, vocab_size_padded_(dynamic_decode_layer.vocab_size_padded_)
, cuda_device_prop_(dynamic_decode_layer.cuda_device_prop_)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
initialize();
}
@ -117,7 +117,7 @@ bool hasDiffRuntimeArgs(DecodingSetupParams const& params)
template <typename T>
void DynamicDecodeLayer<T>::setup(size_t batch_size, size_t beam_width, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (beam_width == 1)
{ // sampling layers
@ -159,7 +159,7 @@ void DynamicDecodeLayer<T>::setup(size_t batch_size, size_t beam_width, SetupPar
template <typename T>
void DynamicDecodeLayer<T>::allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mIdsPtrHost->resize(2 * batch_size);
zero_parent_ids = allocator_->reMalloc(zero_parent_ids, sizeof(int*) * 2 * batch_size, false);
}
@ -167,14 +167,14 @@ void DynamicDecodeLayer<T>::allocateBuffer(size_t batch_size, size_t beam_width,
template <typename T>
void DynamicDecodeLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
allocator_->free((void**) &zero_parent_ids);
}
template <typename T>
void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const& params)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
const auto ite = params.ite;
const auto step = params.step;

View File

@ -97,7 +97,7 @@ void invokeUpdate(FinishedState* finished, int** parent_ids_ptr, int* sequence_l
template <typename T>
void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
BaseBeamSearchLayer<T>::setupBase(batch_size, setupParams);
allocateBuffer(batch_size);
@ -108,13 +108,13 @@ void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setup
fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_);
fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params)
{
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
Tensor const& output_ids_ptr = outputs.output_ids_ptr;
const auto batch_size = static_cast<std::int32_t>(output_ids_ptr.shape[0]);
const auto beam_width = static_cast<std::int32_t>(output_ids_ptr.shape[1]);
@ -159,7 +159,7 @@ void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, So
template <typename T>
void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// we need to check 2 * beam_width candidates each time
// 64 is the max beam width we support now.
topk_softmax_workspace_size_ = (size_t) (ceil(batch_size * 64 * (64 * 2) / 4.) * 4 * 2
@ -171,13 +171,13 @@ void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
length_penalties_buf_ = allocator_->reMalloc(length_penalties_buf_, sizeof(float) * batch_size, false);
is_allocate_buffer_ = true;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void OnlineBeamSearchLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&topk_softmax_workspace_));
@ -185,7 +185,7 @@ void OnlineBeamSearchLayer<T>::freeBuffer()
allocator_->free((void**) (&length_penalties_buf_));
is_allocate_buffer_ = false;
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
@ -199,13 +199,13 @@ template <typename T>
OnlineBeamSearchLayer<T>::OnlineBeamSearchLayer(OnlineBeamSearchLayer<T> const& beam_search_layer)
: BaseBeamSearchLayer<T>(beam_search_layer)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
}
template <typename T>
OnlineBeamSearchLayer<T>::~OnlineBeamSearchLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
}
template class OnlineBeamSearchLayer<float>;
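The comment above notes that `2 * beam_width` candidates are checked each step. A toy, host-only illustration of the usual reason: keeping twice as many candidates as beams leaves `beam_width` unfinished hypotheses to extend even when some of the best candidates end with EOS. Everything below (the struct, the function, the EOS handling) is a simplification, not the CUDA topk-softmax path.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct BeamCandidate
{
    int parentBeam; // which live beam this candidate extends
    int tokenId;    // token appended to that beam
    float score;    // cumulative log-probability
};

// Keep the best beam_width non-finished candidates out of 2 * beam_width.
std::vector<BeamCandidate> selectBeams(std::vector<BeamCandidate> candidates,
    std::size_t beamWidth, int eosTokenId)
{
    // Highest cumulative score first.
    std::sort(candidates.begin(), candidates.end(),
        [](const BeamCandidate& a, const BeamCandidate& b) { return a.score > b.score; });

    std::vector<BeamCandidate> survivors;
    for (const BeamCandidate& c : candidates)
    {
        if (survivors.size() == beamWidth)
        {
            break;
        }
        if (c.tokenId == eosTokenId)
        {
            continue; // finished hypotheses are stored separately, not extended
        }
        survivors.push_back(c);
    }
    return survivors;
}

int main()
{
    const std::size_t beamWidth = 2;
    const int eos = 0;
    // 2 * beam_width candidates: even with an EOS near the top, two beams survive.
    std::vector<BeamCandidate> candidates{
        {0, eos, -0.1f}, {0, 5, -0.3f}, {1, 7, -0.4f}, {1, eos, -0.9f}};
    auto beams = selectBeams(candidates, beamWidth, eos);
    return beams.size() == beamWidth ? 0 : 1;
}
```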

View File

@ -85,7 +85,7 @@ __global__ void setup_topk_runtime_args(int batch_size, uint32_t top_k, uint32_t
template <typename T>
void TopKSamplingLayer<T>::allocateBuffer(size_t const batch_size, std::vector<uint32_t> const& top_k)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
uint32_t max_top_k = (top_k.size() > 0) ? *std::max_element(std::begin(top_k), std::end(top_k)) : 1;
if (max_top_k == 0)
{
@ -104,7 +104,7 @@ void TopKSamplingLayer<T>::allocateBuffer(size_t const batch_size, std::vector<u
template <typename T>
void TopKSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&sampling_workspace_));
@ -118,7 +118,7 @@ void TopKSamplingLayer<T>::freeBuffer()
template <typename T>
void TopKSamplingLayer<T>::setup(size_t const batch_size, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
BaseSamplingLayer<T>::setupBase(batch_size, setupParams);
uint32_t const default_top_k = 0;
@ -162,7 +162,7 @@ void TopKSamplingLayer<T>::setup(size_t const batch_size, SetupParams const& set
template <typename T>
void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingParams const& params)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batch_size = outputs.output_ids_ptr.shape[0];
auto const local_batch_size = params.logits.shape[0];
@ -223,7 +223,7 @@ TopKSamplingLayer<T>::TopKSamplingLayer(TopKSamplingLayer<T> const& top_k_sampli
template <typename T>
TopKSamplingLayer<T>::~TopKSamplingLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}
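For reference, a CPU sketch of top-k sampling as performed by this layer: keep the k largest logits, take a softmax over them only, and draw one token. Per-batch `top_k` values, the shared workspace, and the curand state handling are omitted; the helper below is illustrative, not TensorRT-LLM code.

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

// Draw one token id from the k highest-scoring logits.
int sampleTopK(const std::vector<float>& logits, int topK, std::mt19937& rng)
{
    // Order token ids so the first topK entries hold the largest logits.
    std::vector<int> order(logits.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + topK, order.end(),
        [&logits](int a, int b) { return logits[a] > logits[b]; });

    // Softmax over the k retained logits only.
    std::vector<float> probs(topK);
    float maxLogit = logits[order[0]];
    float sum = 0.0f;
    for (int i = 0; i < topK; ++i)
    {
        probs[i] = std::exp(logits[order[i]] - maxLogit);
        sum += probs[i];
    }

    // Inverse-CDF sampling over the unnormalized probabilities.
    std::uniform_real_distribution<float> uniform(0.0f, sum);
    float r = uniform(rng);
    for (int i = 0; i < topK; ++i)
    {
        r -= probs[i];
        if (r <= 0.0f)
        {
            return order[i];
        }
    }
    return order[topK - 1];
}

int main()
{
    std::mt19937 rng(42); // the layer seeds curand per sequence; here a fixed seed
    std::vector<float> logits{2.0f, 0.1f, 1.5f, -3.0f};
    return sampleTopK(logits, /*topK=*/2, rng) >= 0 ? 0 : 1;
}
```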

View File

@ -105,7 +105,7 @@ static __global__ void set_topp_runtime_args(int batch_size, std::uint32_t top_k
template <typename T>
void TopPSamplingLayer<T>::allocateBuffer(std::size_t batch_size, std::vector<float> const& top_p)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
float const max_top_p = (top_p.size() > 0) ? *std::max_element(std::begin(top_p), std::end(top_p)) : 0.0f;
invokeTopPSampling<T>(nullptr, // workspace
sampling_workspace_size_, cub_temp_storage_size_,
@ -136,7 +136,7 @@ void TopPSamplingLayer<T>::allocateBuffer(std::size_t batch_size, std::vector<fl
template <typename T>
void TopPSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&sampling_workspace_));
@ -157,7 +157,7 @@ void TopPSamplingLayer<T>::freeBuffer()
template <typename T>
void TopPSamplingLayer<T>::setup(std::size_t const batch_size, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
BaseSamplingLayer<T>::setupBase(batch_size, setupParams);
std::uint32_t const default_top_k = 0;
@ -229,7 +229,7 @@ void TopPSamplingLayer<T>::setup(std::size_t const batch_size, SetupParams const
template <typename T>
void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingParams const& params)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
auto const batch_size = outputs.output_ids_ptr.shape[0];
auto const local_batch_size = params.logits.shape[0];
@ -288,7 +288,7 @@ TopPSamplingLayer<T>::TopPSamplingLayer(TopPSamplingLayer<T> const& top_p_sampli
template <typename T>
TopPSamplingLayer<T>::~TopPSamplingLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}
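And a matching CPU sketch of top-p (nucleus) sampling: softmax over the vocabulary, sort by probability, keep the smallest prefix whose mass reaches `top_p`, and sample inside it. The cub-based device sort and the workspace sizing above are omitted; names are illustrative only.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Draw one token id from the smallest set of tokens whose probability mass reaches topP.
int sampleTopP(const std::vector<float>& logits, float topP, std::mt19937& rng)
{
    // Softmax over the full vocabulary.
    float maxLogit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i)
    {
        probs[i] = std::exp(logits[i] - maxLogit);
        sum += probs[i];
    }
    for (float& p : probs)
    {
        p /= sum;
    }

    // Sort token ids by descending probability and find the nucleus.
    std::vector<int> order(probs.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
        [&probs](int a, int b) { return probs[a] > probs[b]; });

    float cumulative = 0.0f;
    std::size_t nucleusSize = 0;
    while (nucleusSize < order.size() && cumulative < topP)
    {
        cumulative += probs[order[nucleusSize]];
        ++nucleusSize;
    }

    // Sample proportionally within the nucleus.
    std::uniform_real_distribution<float> uniform(0.0f, cumulative);
    float r = uniform(rng);
    for (std::size_t i = 0; i < nucleusSize; ++i)
    {
        r -= probs[order[i]];
        if (r <= 0.0f)
        {
            return order[i];
        }
    }
    return order[nucleusSize - 1];
}

int main()
{
    std::mt19937 rng(7);
    std::vector<float> logits{2.0f, 0.1f, 1.5f, -3.0f};
    return sampleTopP(logits, /*topP=*/0.9f, rng) >= 0 ? 0 : 1;
}
```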

Some files were not shown because too many files have changed in this diff.