Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Update TensorRT-LLM (#667)
* Update TensorRT-LLM

---------

Co-authored-by: 0xymoro <jerrymeng100@gmail.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Parent: f7eca56161
Commit: a75618df24
README.md (14 changed lines)
@@ -150,10 +150,13 @@ mkdir -p ./bloom/560M && git clone https://huggingface.co/bigscience/bloom-560m
```
***2. Build the engine***

```python
```bash
# Single GPU on BLOOM 560M
python build.py --model_dir ./bloom/560M/ \
python convert_checkpoint.py --model_dir ./bloom/560M/ \
--dtype float16 \
--output_dir ./bloom/560M/trt_ckpt/fp16/1-gpu/
# May need to add trtllm-build to PATH, export PATH=/usr/local/bin:$PATH
trtllm-build --checkpoint_dir ./bloom/560M/trt_ckpt/fp16/1-gpu/ \
--use_gemm_plugin float16 \
--use_gpt_attention_plugin float16 \
--output_dir ./bloom/560M/trt_engines/fp16/1-gpu/
@@ -166,7 +169,7 @@ See the BLOOM [example](examples/bloom) for more details and options regarding t
The `../summarize.py` script can be used to perform the summarization of articles
from the CNN Daily dataset:

```python
```bash
python ../summarize.py --test_trt_llm \
--hf_model_dir ./bloom/560M/ \
--data_type fp16 \
@@ -244,10 +247,12 @@ the models listed in the [examples](examples/.) folder.
The list of supported models is:

* [Baichuan](examples/baichuan)
* [BART](examples/enc_dec)
* [Bert](examples/bert)
* [Blip2](examples/blip2)
* [BLOOM](examples/bloom)
* [ChatGLM](examples/chatglm)
* [FairSeq NMT](examples/nmt)
* [Falcon](examples/falcon)
* [Flan-T5](examples/enc_dec)
* [GPT](examples/gpt)
@@ -257,6 +262,7 @@ The list of supported models is:
* [InternLM](examples/internlm)
* [LLaMA](examples/llama)
* [LLaMA-v2](examples/llama)
* [mBART](examples/enc_dec)
* [Mistral](examples/llama#mistral-v01)
* [MPT](examples/mpt)
* [mT5](examples/enc_dec)
@@ -269,7 +275,7 @@ The list of supported models is:
* [Whisper](examples/whisper)

Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder
functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, etc. We
functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, NMT family, etc. We
unroll the exact model names in the list above to let users find specific
models easier.

@@ -96,7 +96,12 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

// copy inputs and wrap into shared_ptr
GenerationInput::TensorPtr inputIds;
std::vector<int32_t> inputsHost(batchSize * maxInputLength, padId);
std::vector<int32_t> inputsHost(batchSize * maxInputLength);
srand(time(0));
for (int i = 0; i < inputsHost.size(); i++)
{
inputsHost[i] = rand() % modelConfig.getVocabSizePadded(worldConfig.getSize());
}

if (inputPacked)
{

@@ -12,14 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, Optional
from dataclasses import asdict, dataclass
from typing import Optional

from pydantic import BaseModel, Extra

from tensorrt_llm.functional import PositionEmbeddingType
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

class BuildConfig(BaseModel, extra=Extra.allow):
@dataclass
class BuildConfig:
num_layers: int
num_heads: int
hidden_size: int
@@ -41,20 +44,33 @@ class BuildConfig(BaseModel, extra=Extra.allow):
enable_qk_half_accum: bool = False
enable_context_fmha: bool = True
enable_multi_block_mode: bool = False
# The enum name of PositionEmbeddingType
# None means using the model family's default value defined in the ctor
position_embedding_type: Optional[PositionEmbeddingType] = None
position_embedding_type: str = None
# Only when position embedding is RoPE, this value makes sense, make
# default value to be None, not 0 or 1 to prevent misuse
# default value to be None, not the others to prevent misuse
rotary_pct: Optional[float] = None
bias: bool = True
quantization: Optional[str] = None
# use_custom_all_reduce gives better performance with NVLink
use_custom_all_reduce: bool = True
moe_num_experts: int = None
moe_top_k: int = None
moe_num_experts: int = 0
moe_top_k: int = 0
use_alibi: bool = None
remove_input_padding: bool = None
parallel_attention: bool = None
new_decoder_architecture: bool = None

class EncDecBuildConfig(BuildConfig, extra=Extra.allow):
@dataclass
class EncDecBuildConfig:
num_layers: int
num_heads: int
hidden_size: int
vocab_size: int
hidden_act: Optional[str]
n_positions: int
max_batch_size: int
num_decoder_layers: Optional[int] = None
head_size: Optional[int] = None
ffn_hidden_size: Optional[int] = None
@@ -62,6 +78,8 @@ class EncDecBuildConfig(BuildConfig, extra=Extra.allow):
max_distance: Optional[int] = None
max_encoder_input_len: Optional[int] = None
max_decoder_input_len: Optional[int] = None
max_output_len: Optional[int] = None
builder_opt: Optional[int] = None

def __post_init__(self) -> None:
assert self.head_size is not None
@@ -69,7 +87,8 @@ class EncDecBuildConfig(BuildConfig, extra=Extra.allow):
assert self.num_buckets is not None

class ModelConfig(BaseModel):
@dataclass
class ModelConfig:
name: str
family: str
benchmark_type: Literal["gpt", "bert", "enc_dec"]
@@ -192,7 +211,7 @@ _allowed_configs = {
max_input_len=512,
max_output_len=200,
builder_opt=None,
position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
position_embedding_type='rope_gpt_neox',
rotary_pct=0.5,
bias=False,
)),
@@ -285,6 +304,25 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
)),
"llama_7b_moe":
ModelConfig(name="llama_7b_moe",
family="llama",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
num_heads=32,
hidden_size=4096,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
moe_num_experts=4,
moe_top_k=1,
)),
"llama_13b":
ModelConfig(name="llama_13b",
family="llama",
@@ -849,7 +887,7 @@ def get_allowed_models(benchmark_type=None):

def get_build_config(model_name):
if model_name in _allowed_configs:
return dict(_allowed_configs[model_name].build_config)
return asdict(_allowed_configs[model_name].build_config)
else:
raise KeyError(f'Unexpected model: {model_name}. Please add the model '
'to allowed_configs.py')
@@ -861,3 +899,11 @@ def get_model_family(model_name):
else:
raise KeyError(f'Unexpected model: {model_name}. Please add the model '
'to allowed_configs.py')

def get_benchmark_type(model_name):
if model_name in _allowed_configs:
return _allowed_configs[model_name].benchmark_type
else:
raise KeyError(f'Unexpected model: {model_name}. Please add the model '
'to allowed_configs.py')

@@ -124,6 +124,7 @@ class BaseBenchmark(object):
report_dict["model_name"] = self.model_name
report_dict["world_size"] = self.world_size
report_dict["precision"] = self.dtype
report_dict["quantization"] = str(self.quant_mode)
report_dict["compute_cap"] = "sm" + get_compute_cap()
return report_dict

@@ -27,7 +27,7 @@ from base_benchmark import get_engine_name, serialize_engine
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.layers import MoeConfig, PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.models import PretrainedConfig, quantize_model
from tensorrt_llm.network import net_guard
@@ -284,12 +284,12 @@ def build_gpt(args):
apply_query_key_layer_scaling,
position_embedding_type=PositionEmbeddingType.learned_absolute
if build_config['position_embedding_type'] is None else
build_config['position_embedding_type'],
PositionEmbeddingType[build_config['position_embedding_type']],
rotary_embedding_percentage=build_config['rotary_pct'],
quant_mode=quant_mode,
bias=build_config['bias'],
moe_layer_config=tensorrt_llm.moe_config.MoeLayerConfig(
build_config["moe_num_experts"], build_config["moe_top_k"]))
moe_config=MoeConfig(build_config["moe_num_experts"],
build_config["moe_top_k"]))
elif family == "opt":
config = {
'architecture': 'OPTForCausalLM',
@@ -307,7 +307,29 @@ def build_gpt(args):
'use_parallel_embedding': False,
'share_embedding_table': False,
'embedding_sharding_dim': 0,
'do_layer_norm_before': build_config['do_layer_norm_before']
'do_layer_norm_before': build_config['do_layer_norm_before'],
'quantization': {
'use_smooth_quant':
quant_mode.has_act_and_weight_quant(),
'per_channel':
quant_mode.has_per_channel_scaling(),
'per_token':
quant_mode.has_per_token_dynamic_scaling(),
'per_group':
quant_mode.has_per_group_scaling(),
'group_size':
128,
'int8_kv_cache':
quant_mode.has_int8_kv_cache(),
'enable_fp8':
quant_mode.has_fp8_qdq(),
'fp8_kv_cache':
quant_mode.has_fp8_kv_cache(),
'use_weight_only':
quant_mode.is_weight_only(),
'weight_only_precision':
'int8' if quant_mode.is_int8_weight_only() else 'int4',
}
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.OPTForCausalLM(config)
@@ -325,7 +347,9 @@ def build_gpt(args):
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size),  # TP only
quant_mode=quant_mode,
use_fused_mlp=True)
use_fused_mlp=True,
moe_config=MoeConfig(build_config["moe_num_experts"],
build_config["moe_top_k"]))
elif family == "gptj":
tensorrt_llm_model = tensorrt_llm.models.GPTJForCausalLM(
num_layers=build_config['num_layers'],
@@ -402,17 +426,47 @@ def build_gpt(args):
model_name="chatglm3_6b")

elif family == "bloom":
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size),  # TP only
quant_mode=quant_mode,
use_parallel_embedding=(args.model == 'bloom_176b'))
config = {
'architecture': 'BloomForCausalLM',
'dtype': args.dtype,
'vocab_size': build_config['vocab_size'],
'hidden_size': build_config['hidden_size'],
'num_hidden_layers': build_config['num_layers'],
'num_attention_heads': build_config['num_heads'],
'hidden_act': build_config['hidden_act'],
'max_position_embeddings': build_config['n_positions'],
'mapping': {
'world_size': world_size,
'tp_size': world_size
},
'use_parallel_embedding': (args.model == 'bloom_176b'),
'share_embedding_table': False,
'embedding_sharding_dim': 0,
'quantization': {
'use_smooth_quant':
quant_mode.has_act_and_weight_quant(),
'per_channel':
quant_mode.has_per_channel_scaling(),
'per_token':
quant_mode.has_per_token_dynamic_scaling(),
'per_group':
quant_mode.has_per_group_scaling(),
'group_size':
128,
'int8_kv_cache':
quant_mode.has_int8_kv_cache(),
'enable_fp8':
quant_mode.has_fp8_qdq(),
'fp8_kv_cache':
quant_mode.has_fp8_kv_cache(),
'use_weight_only':
quant_mode.is_weight_only(),
'weight_only_precision':
'int8' if quant_mode.is_int8_weight_only() else 'int4',
}
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(config)
elif family == "falcon":
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(
num_layers=build_config['num_layers'],
@@ -447,8 +501,9 @@ def build_gpt(args):
"zero": True,
"pre_quant_scale": False,
}
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
**quant_kwargs)
if family not in ['opt', 'bloom']:
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
**quant_kwargs)

# Module -> Network
network = builder.create_network()
@@ -495,7 +550,10 @@ def build_gpt(args):
max_input_len,
max_output_len, True,
max_beam_width)
tensorrt_llm_model(*inputs)
if family in ['opt', 'bloom']:
tensorrt_llm_model(**inputs)
else:
tensorrt_llm_model(*inputs)

if args.mode == 'plugin':
tensorrt_llm.graph_rewriting.optimize(network)

@@ -137,7 +137,7 @@ endif()

message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES}")

enable_language(CUDA)
enable_language(C CXX CUDA)

find_package(CUDAToolkit REQUIRED)

@@ -254,7 +254,7 @@ public:
void startScheduling();

//! \brief Assign blocks for new sequence. Try to reuse blocks.
void addSequence(GenerationRequest& sequence, std::shared_ptr<LlmRequest> const& llmRequest);
void addSequence(GenerationRequest& sequence, SizeType inputLength, std::shared_ptr<LlmRequest> const& llmRequest);

//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType inputLength, bool enableCyclicKvCache);
@@ -449,6 +449,17 @@ public:
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig,
runtime::BufferManager const& bufferManager);

[[nodiscard]] SizeType getNumPrepopulatedTokens(SizeType batchSlotIdx, SizeType beamIdx) const
{
auto const& prepopulatedTokens = mSequences.at(batchSlotIdx)->getNumPrepopulatedTokens();
return prepopulatedTokens.size() > 0 ? prepopulatedTokens.at(beamIdx) : 0;
}

[[nodiscard]] bool isEnableBlockReuse() const
{
return mEnableBlockReuse;
}

private:
void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);

@@ -34,21 +34,18 @@ public:

explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
bool useContextFMHAForGeneration = false,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt)
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt)
: kvCacheConfig{kvCacheConfig}
, maxNumSequences{maxNumSequences}
, enableTrtOverlap{enableTrtOverlap}
, useContextFMHAForGeneration(useContextFMHAForGeneration)
, userSpecifiedDeviceIds(userSpecifiedDeviceIds)
, deviceIds(deviceIds)
{
}

KvCacheConfig kvCacheConfig;
std::optional<SizeType> maxNumSequences;
bool enableTrtOverlap;
bool useContextFMHAForGeneration;
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds;
std::optional<std::vector<SizeType>> deviceIds;
};

} // namespace tensorrt_llm::batch_manager

@@ -56,6 +56,8 @@ public:
, mUseCustomAllReduce(false)
, mMaxPromptEmbeddingTableSize(0)
, mMaxDraftLen(0)
, mUseContextFMHAForGeneration(false)
, mPagedContextFMHA(false)
{
}

@@ -280,6 +282,26 @@ public:
return mMaxDraftLen + 1;
}

void constexpr setUseContextFMHAForGeneration(bool useContextFMHAForGeneration) noexcept
{
mUseContextFMHAForGeneration = useContextFMHAForGeneration;
}

[[nodiscard]] bool constexpr getContextFMHAForGeneration() const noexcept
{
return mUseContextFMHAForGeneration;
}

void constexpr setPagedContextFMHA(bool pagedContextFMHA) noexcept
{
mPagedContextFMHA = pagedContextFMHA;
}

[[nodiscard]] bool constexpr getPagedContextFMHA() const noexcept
{
return mPagedContextFMHA;
}

private:
SizeType mVocabSize;
SizeType mNbLayers;
@@ -305,6 +327,9 @@ private:

SizeType mMaxPromptEmbeddingTableSize;
SizeType mMaxDraftLen;

bool mUseContextFMHAForGeneration;
bool mPagedContextFMHA;
};

} // namespace tensorrt_llm::runtime

@@ -30,14 +30,8 @@ public:
static SizeType constexpr kDefaultGpusPerNode = 8;

explicit WorldConfig(SizeType tensorParallelism = 1, SizeType pipelineParallelism = 1, SizeType rank = 0,
SizeType gpusPerNode = kDefaultGpusPerNode, std::vector<SizeType> deviceIds = {})
: mTensorParallelism{tensorParallelism}
, mPipelineParallelism{pipelineParallelism}
, mRank{rank}
, mGpusPerNode{gpusPerNode}
, mDeviceIds{deviceIds}
{
}
SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt);

[[nodiscard]] SizeType constexpr getSize() const noexcept
{
@@ -74,13 +68,14 @@ public:
return mGpusPerNode;
}

[[nodiscard]] SizeType getGpusPerGroup() const noexcept
{
return static_cast<SizeType>(mDeviceIds.size());
}

[[nodiscard]] SizeType getDevice() const noexcept
{
if (mDeviceIds.size())
{
return mDeviceIds[mRank % mGpusPerNode];
}
return mRank % mGpusPerNode;
return mDeviceIds[mRank % getGpusPerGroup()];
}

[[nodiscard]] SizeType constexpr getPipelineParallelRank() const noexcept
@@ -116,12 +111,12 @@ public:
static WorldConfig mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType> tensorParallelism = std::nullopt,
std::optional<SizeType> pipelineParallelism = std::nullopt,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt);
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt);

static WorldConfig mpi(SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType> tensorParallelism = std::nullopt,
std::optional<SizeType> pipelineParallelism = std::nullopt,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt);
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt);

private:
SizeType mTensorParallelism;
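The WorldConfig change above replaces the defaulted in-header constructor with an out-of-line declaration that takes an optional list of device IDs, and `getDevice()` now indexes that list modulo `getGpusPerGroup()`. A minimal caller-side sketch of what this enables; the usage is hypothetical and only the declarations shown in the hunks above are assumed:

```cpp
#include <optional>
#include <vector>

#include "tensorrt_llm/runtime/worldConfig.h" // assumed include path for the header above

using tensorrt_llm::runtime::SizeType;
using tensorrt_llm::runtime::WorldConfig;

// Run a 2-way tensor-parallel job on GPUs 2 and 3 of a node.
// getDevice() returns mDeviceIds[rank % getGpusPerGroup()], so with
// deviceIds = {2, 3} rank 0 maps to GPU 2 and rank 1 maps to GPU 3.
WorldConfig makeConfig(SizeType rank)
{
    std::optional<std::vector<SizeType>> const deviceIds = std::vector<SizeType>{2, 3};
    return WorldConfig(/*tensorParallelism=*/2, /*pipelineParallelism=*/1, rank,
        WorldConfig::kDefaultGpusPerNode, deviceIds);
}
```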
@@ -24,11 +24,11 @@ set(STATIC_TARGET
set(API_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include)

find_package(MPI REQUIRED)
message(STATUS "Using MPI_CXX_INCLUDE_DIRS: ${MPI_CXX_INCLUDE_DIRS}")
message(STATUS "Using MPI_CXX_LIBRARIES: ${MPI_CXX_LIBRARIES}")
message(STATUS "Using MPI_C_INCLUDE_DIRS: ${MPI_C_INCLUDE_DIRS}")
message(STATUS "Using MPI_C_LIBRARIES: ${MPI_C_LIBRARIES}")

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cutlass_extensions/include
${API_INCLUDE_DIR} ${MPI_INCLUDE_PATH})
${API_INCLUDE_DIR} ${MPI_C_INCLUDE_DIRS})

add_subdirectory(common)
add_subdirectory(kernels)
@@ -114,7 +114,7 @@ set(TRTLLM_LINK_LIBS
${CUBLASLT_LIB}
${CUDNN_LIB}
${CMAKE_DL_LIBS}
${MPI_CXX_LIBRARIES}
${MPI_C_LIBRARIES}
${NCCL_LIB}
${TRT_LIB}
common_src

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d9f7d0f7dee2c48a424ff8873c2fd1298a27850f870657734641f2eb1190faf
size 1791038
oid sha256:51f905eed7ac6f5dbf12736519961100b8ac5f270cb96a79dd74c8f0a6837f24
size 1801452

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa79a0d563fc01a0cb2fe94dcb626ff4e5b736284d9244313cbe7aa0261dd48e
size 1806500
oid sha256:17ea5ea3b9cf666091a2997da2c7980e36a8b59b320668c5453886fb766c2db5
size 1819266

@@ -1,3 +1,3 @@
d9723ab671c9fc3889cc624a58def81a libtensorrt_llm_batch_manager_static.a
4b6773c990e8a59f1c716d88505b84a2 libtensorrt_llm_batch_manager_static.pre_cxx11.a
9a136bb59c51bbae09221c1667e23529ed05c752 commit
d7a189cfdfbc3ebe45f07faf7be61434 libtensorrt_llm_batch_manager_static.a
f18c7b1d843bbcc48a63fb0eedc3ade5 libtensorrt_llm_batch_manager_static.pre_cxx11.a
5218d255437b68c905d8be35a959405c20e6254c commit

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a7b872fe6ee63a4342c3cd17b3557d74c72e537dbf0d4ddf132a2c40e000e57
size 1709462
oid sha256:1d2fd4c684ea3de95fb1070e28e2938760d024fec0bd5d710585c3f48c659b8f
size 1721606

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c83f7c0e4fc22b32df669ada2b99b88f0f7faac935a251fe7a20030e2b364cc8
size 1705432
oid sha256:ec5f659d47742f96f36385cfac33a5a5d4159acf0ee029d9c01c56eb5afb9afc
size 1715582

@@ -1,2 +1,2 @@
583141c3003a08acebc7054d024bee89 libtensorrt_llm_batch_manager_static.a
03e5360f9b8074b8273500898581212f libtensorrt_llm_batch_manager_static.pre_cxx11.a
723277a6fc05a0589b96a66893a8518b libtensorrt_llm_batch_manager_static.a
40defd0eb709bea4c0b8e363fd264142 libtensorrt_llm_batch_manager_static.pre_cxx11.a

@@ -48,7 +48,7 @@ public:
template <typename T>
[[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, const bool setZero = true)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
// TODO martinma: why do we need this size extension?
auto const sizeAligned = ((sizeBytes + 31) / 32) * 32; // make the buffer align with 32 bytes
if (contains(ptr))

@@ -48,7 +48,7 @@ ReallocType CudaAllocator::reallocType(void const* ptr, size_t size) const

void* CudaAllocator::malloc(std::size_t size, bool const setZero)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
auto bufferPtr = mBufferManager.gpu(size);
if (setZero)
{
@@ -62,7 +62,7 @@ void* CudaAllocator::malloc(std::size_t size, bool const setZero)

void CudaAllocator::free(void** ptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mPointerMapping.erase(*ptr);
*ptr = nullptr;
}

@@ -297,9 +297,12 @@ inline int getMaxSharedMemoryPerBlockOptin()
return max_shared_memory_per_block;
}

inline int divUp(int a, int n)
template <typename T1, typename T2>
inline size_t divUp(const T1& a, const T2& n)
{
return (a + n - 1) / n;
size_t tmp_a = static_cast<size_t>(a);
size_t tmp_n = static_cast<size_t>(n);
return (tmp_a + tmp_n - 1) / tmp_n;
}

template <typename T, typename U, typename = std::enable_if_t<std::is_integral<T>::value>,
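The templated `divUp` above promotes both operands to `size_t` before rounding up, whereas the old `int` version could overflow in the intermediate `a + n - 1` for large byte counts. A small self-contained check of the new behavior (the values are chosen only for illustration):

```cpp
#include <cassert>
#include <cstddef>

// Same shape as the templated divUp introduced in the hunk above.
template <typename T1, typename T2>
inline size_t divUp(const T1& a, const T2& n)
{
    size_t tmp_a = static_cast<size_t>(a);
    size_t tmp_n = static_cast<size_t>(n);
    return (tmp_a + tmp_n - 1) / tmp_n;
}

int main()
{
    // 2'147'483'600 + 255 would overflow a 32-bit int inside the old divUp;
    // in size_t arithmetic the rounded-up block count comes out exact.
    std::size_t const numBytes = 2'147'483'600u;
    assert(divUp(numBytes, 256) == 8'388'608u);
    return 0;
}
```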
cpp/tensorrt_llm/common/envUtils.cpp (new file, 65 lines)
@@ -0,0 +1,65 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <cstdlib>

namespace tensorrt_llm::common
{

// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug()
{
static bool init = false;
static bool forceMmhaMaxSeqLenTile = false;
if (!init)
{
init = true;
const char* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
if (enable_mmha_debug_var)
{
if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
{
forceMmhaMaxSeqLenTile = true;
}
}
}
return forceMmhaMaxSeqLenTile;
}

int getEnvMmhaBlocksPerSequence()
{
static bool init = false;
static int mmhaBlocksPerSequence = 0;
if (!init)
{
init = true;
const char* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
if (mmhaBlocksPerSequenceEnv)
{
mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
if (mmhaBlocksPerSequence <= 0)
{
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!");
}
}
}
return mmhaBlocksPerSequence;
}

} // namespace tensorrt_llm::common
cpp/tensorrt_llm/common/envUtils.h (new file, 28 lines)
@@ -0,0 +1,28 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace tensorrt_llm::common
{

// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug();

int getEnvMmhaBlocksPerSequence();

} // namespace tensorrt_llm::common
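Both helpers read their environment variable once, cache the result in function-local statics, and otherwise fall back to a disabled or zero value. A hedged sketch of how a caller might consume them; the dispatch policy below is hypothetical, and only the two function names and the `TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG` / `TRTLLM_MMHA_BLOCKS_PER_SEQUENCE` variables come from this diff:

```cpp
// Declarations as introduced in envUtils.h above.
namespace tensorrt_llm::common
{
bool getEnvMmhaMultiblockDebug();
int getEnvMmhaBlocksPerSequence();
} // namespace tensorrt_llm::common

// Hypothetical dispatcher: override the multi-block MMHA tiling only when the
// debug switch is enabled and a positive block count was supplied.
int pickMmhaBlocksPerSequence(int heuristicBlocks)
{
    using namespace tensorrt_llm::common;
    if (getEnvMmhaMultiblockDebug())
    {
        int const requested = getEnvMmhaBlocksPerSequence();
        return requested > 0 ? requested : heuristicBlocks;
    }
    return heuristicBlocks;
}
```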
@@ -233,7 +233,7 @@ public:
template <typename T>
inline T getVal(size_t index) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(where == MEMORY_CPU);
TLLM_CHECK(data != nullptr);
TLLM_CHECK_WITH_INFO(index < size(), "index is larger than buffer size");
@@ -249,7 +249,7 @@ public:
template <typename T>
inline T getVal() const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (getTensorType<T>() != type)
{
TLLM_LOG_DEBUG("getVal with type %s, but data type is: %s", getNumpyTypeDesc(getTensorType<T>()).c_str(),
@@ -261,7 +261,7 @@ public:
template <typename T>
inline T* getPtr() const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (getTensorType<T>() != type)
{
TLLM_LOG_DEBUG("getPtr with type %s, but data type is: %s", getNumpyTypeDesc(getTensorType<T>()).c_str(),
@@ -272,7 +272,7 @@ public:

inline void* getPtrWithOffset(size_t offset) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (data == nullptr)
{
return (void*) data;
@@ -287,7 +287,7 @@ public:
template <typename T>
inline T* getPtrWithOffset(size_t offset) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (getTensorType<T>() != type)
{
TLLM_LOG_DEBUG("getVal with type %s, but data type is: %s", getNumpyTypeDesc(getTensorType<T>()).c_str(),
@@ -431,7 +431,7 @@ public:

inline bool contains(const std::string& key) const
{
TLLM_LOG_DEBUG("%s for key: %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key: %s", __PRETTY_FUNCTION__, key.c_str());
return tensor_map_.find(key) != tensor_map_.end();
}

@@ -464,7 +464,7 @@ public:

inline Tensor& at(const std::string& key)
{
TLLM_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_CHECK_WITH_INFO(contains(key),
fmtstr(
"Cannot find a tensor of name %s in the tensor map (keys: %s)", key.c_str(), vec2str(keys()).c_str()));
@@ -489,7 +489,7 @@ public:

inline Tensor& at(const std::string& key, Tensor& default_tensor)
{
TLLM_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key))
{
return tensor_map_.at(key);
@@ -499,7 +499,7 @@ public:

inline Tensor at(const std::string& key, Tensor& default_tensor) const
{
TLLM_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key))
{
return tensor_map_.at(key);
@@ -509,7 +509,7 @@ public:

inline Tensor& at(const std::string& key, Tensor&& default_tensor)
{
TLLM_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key))
{
return tensor_map_.at(key);

@@ -50,7 +50,7 @@ TllmException::TllmException(char const* file, std::size_t line, const std::stri
}
#endif

TllmException::~TllmException() = default;
TllmException::~TllmException() noexcept = default;

std::string TllmException::getTrace() const
{

cpp/tensorrt_llm/common/workspace.h (new file, 82 lines)
@@ -0,0 +1,82 @@
/*
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
#include <cstdint>

namespace tensorrt_llm::common
{

std::uintptr_t constexpr kCudaMemAlign = 128;

namespace
{

int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
uintptr_t addr = (uintptr_t) ptr;
if (addr % to)
{
addr += to - addr % to;
}
return (int8_t*) addr;
}

int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, const uintptr_t alignment)
{
uintptr_t addr = (uintptr_t) ptr;
addr += previousWorkspaceSize;
return alignPtr((int8_t*) addr, alignment);
}

int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign);
}

int8_t* nextWorkspacePtr(
int8_t* const base, uintptr_t& offset, const uintptr_t size, const uintptr_t alignment = kCudaMemAlign)
{
uintptr_t curr_offset = offset;
uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment;
int8_t* newptr = size == 0 ? nullptr : base + curr_offset;
offset = next_offset;
return newptr;
}

int8_t* nextWorkspacePtrWithAlignment(
int8_t* ptr, uintptr_t previousWorkspaceSize, const uintptr_t alignment = kCudaMemAlign)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
}

size_t calculateTotalWorkspaceSize(size_t* workspaces, int count, const uintptr_t alignment = kCudaMemAlign)
{
size_t total = 0;
for (int i = 0; i < count; i++)
{
total += workspaces[i];
if (workspaces[i] % alignment)
{
total += alignment - (workspaces[i] % alignment);
}
}
return total;
}

} // namespace

}; // namespace tensorrt_llm::common
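A short usage sketch of the helpers above: size a single allocation for two sub-buffers, then carve it up while keeping every sub-buffer aligned to `kCudaMemAlign` (128 bytes). The buffer sizes are made up and the include path is assumed from the file location:

```cpp
#include <cstddef>
#include <cstdint>

#include "tensorrt_llm/common/workspace.h" // assumed include path for the new header

void partitionWorkspace(int8_t* workspace)
{
    using namespace tensorrt_llm::common;

    size_t sizes[] = {1000, 4096};
    // 1000 rounds up to 1024 and 4096 is already aligned, so the total is 5120 bytes.
    size_t const totalBytes = calculateTotalWorkspaceSize(sizes, 2);
    (void) totalBytes; // the caller would allocate this many bytes as `workspace`

    uintptr_t offset = 0;
    int8_t* bufA = nextWorkspacePtr(workspace, offset, sizes[0]); // offset becomes 1024
    int8_t* bufB = nextWorkspacePtr(workspace, offset, sizes[1]); // offset becomes 5120
    (void) bufA;
    (void) bufB;
}
```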
@@ -44,7 +44,7 @@ struct MixedGemmArchTraits<float, float, arch>
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;

static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;

@@ -52,7 +52,7 @@ template <typename TypeB>
struct LayoutDetailsB<TypeB, arch::Sm70>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 8;
using Operator = cutlass::arch::OpMultiplyAdd;
};
@@ -63,7 +63,7 @@ template <typename Arch>
struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
@@ -72,7 +72,7 @@ template <typename Arch>
struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};

@@ -507,6 +507,10 @@ public:
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif

(Several other file diffs in this commit are suppressed because they are too large.)
@@ -90,22 +90,24 @@ public:
pagedKVXmmaKernel = getPagedKVXMMAKernelsV2(mDataType, sm);
xmmaKernel = getXMMAKernelsV2(mDataType, sm);

params.clear();
pagedKVParams.clear();
mParams.clear();
mPagedKVParams.clear();

// get device attributes
int device_id;
cudaGetDevice(&device_id);
cudaDeviceGetAttribute(&launch_params.multi_processor_count, cudaDevAttrMultiProcessorCount, device_id);
cudaDeviceGetAttribute(&launch_params.device_l2_cache_size, cudaDevAttrL2CacheSize, device_id);
cudaDeviceGetAttribute(&mLaunchParams.multi_processor_count, cudaDevAttrMultiProcessorCount, device_id);
cudaDeviceGetAttribute(&mLaunchParams.device_l2_cache_size, cudaDevAttrL2CacheSize, device_id);
}

~mhaImpl() {}

// Support packed QKV.
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
// Shared setup function.
template <typename Params>
void setup_params(Params& params, const int b, const int s_q, const int s_kv, const int sliding_window_size,
const int total_seqlen, const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank)
{

const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
// Note that we apply scales and bias in the order of
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
@@ -114,79 +116,108 @@ public:
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;

Data_type scale_type = launch_params.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
set_alpha(params.scale_bmm1, scale_bmm1, scale_type);
Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
// Use specialized ws kernels on Hopper for cases without alibi.
if (mLaunchParams.useKernelWithoutAlibi)
{
// The kernel adopts the log2f optimziation.
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
set_alpha(params.scale_bmm1, scale_bmm1 * float(kLog2e), DATA_TYPE_FP32);
}
else
{
set_alpha(params.scale_bmm1, scale_bmm1, scale_type);
}
set_alpha(params.scale_softmax, scale_softmax, scale_type);
set_alpha(params.scale_bmm2, scale_bmm2, scale_type);

params.b = b;
params.h = mNumHeads;
params.s = s;
params.s = s_q;
params.d = mHeadSize;
params.sliding_window_size = sliding_window_size;

params.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);

// Total sequence length needed by TMA descriptor
// it should be actual total seq length if non-padded input is given.
mTotalSeqLen = total_seqlen;

params.qkv_stride_in_bytes = (mNumHeads + 2 * params.h_kv) * mHeadSize * sizeof(half);
params.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);

// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
const bool isSm90 = (sm == kSM_90);
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
const bool isSm80 = (sm == kSM_80);
if (isSm90 && params.d <= 64 && params.s <= 256)
{
launch_params.flash_attention = false;
// get max sequence length for non-flash-attentio
launch_params.kernel_s = getSFromMaxSeqLen(params.s);
}
else
{ // always use flash attention kernels for Ampere/Ada
launch_params.flash_attention = true;
// flash attention kernles s = 0 (support any seq length)
launch_params.kernel_s = 0;
launch_params.force_unroll = true;
// enable tiled kernels on Ampere/Ada
if (launch_params.flash_attention && params.s <= 64)
{
// flash attention tiled kernels allows larger free dim tile size (M, N) with flexibility
// in unroll dimension tile size (K). for short sequence length (s<=128), tiled kernels
// can suffer from tile quantization loss therefore use flash attention non-tiled instead
launch_params.granular_tiling = false;
}
else if (isSm8x && params.d < 256)
{
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
launch_params.granular_tiling = false;
}
else if (isSm80 || isSm8x)
{
// otherwise, choose tiled kernel for Ampere/Ada
launch_params.granular_tiling = true;
}
}

// when flash attention is enabled on Hopper, we need to set the tma descriptors
if (isSm90 && launch_params.flash_attention)
{
launch_params.warp_specialization = true;
launch_params.use_tma = true;
}

// alibi.
if (has_alibi)
{
params.has_alibi = true;
params.alibi_params = AlibiParams(mNumHeads, s, tp_size, tp_rank, scale_after_alibi);
params.alibi_params = AlibiParams(mNumHeads, s_kv, tp_size, tp_rank, scale_after_alibi);
}
}
// Support packed QKV.
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
{

// Determine launch parameters.
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
mLaunchParams.set_default_kernel_selection_params();

const bool isSm90 = (sm == kSM_90);
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
const bool isSm80 = (sm == kSM_80);
if (isSm90 && mHeadSize <= 64 && s <= 256)
{
mLaunchParams.flash_attention = false;
// get max sequence length for non-flash-attentio
mLaunchParams.kernel_s = getSFromMaxSeqLen(s);
}
else
{ // always use flash attention kernels for Ampere/Ada
mLaunchParams.flash_attention = true;
// flash attention kernles s = 0 (support any seq length)
mLaunchParams.kernel_s = 0;
mLaunchParams.force_unroll = true;
// enable tiled kernels on Ampere/Ada
if (mLaunchParams.flash_attention && s <= 64)
{
// flash attention tiled kernels allows larger free dim tile size (M, N) with flexibility
// in unroll dimension tile size (K). for short sequence length (s<=128), tiled kernels
// can suffer from tile quantization loss therefore use flash attention non-tiled instead
mLaunchParams.granular_tiling = false;
}
else if (isSm8x && mHeadSize < 256)
{
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
mLaunchParams.granular_tiling = false;
}
else if (isSm80 || isSm8x)
{
// otherwise, choose tiled kernel for Ampere/Ada
mLaunchParams.granular_tiling = true;
}
}

// when flash attention is enabled on Hopper, we need to set the tma descriptors
if (isSm90 && mLaunchParams.flash_attention)
{
mLaunchParams.warp_specialization = true;
mLaunchParams.use_tma = true;
}

// Use specialized ws kernels on Hopper for cases without alibi.
if (mLaunchParams.warp_specialization && !has_alibi)
{
// Use specialized ws kernels for cases without alibi.
mLaunchParams.useKernelWithoutAlibi = true;
}

// Sliding_window_causal mask.
if (s > sliding_window_size && launch_params.attention_mask_type == ContextAttentionMaskType::CAUSAL)
if (s > sliding_window_size && mLaunchParams.attention_mask_type == ContextAttentionMaskType::CAUSAL)
{
params.sliding_window_size = sliding_window_size;
launch_params.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
mLaunchParams.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
}

// Set kernel parameters.
setup_params(mParams, b, s, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
mParams.qkv_stride_in_bytes = (mNumHeads + 2 * mParams.h_kv) * mHeadSize * sizeof(half);
}
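The `useKernelWithoutAlibi` path above multiplies the BMM1 scale by log2(e) because, per the in-code comment, those warp-specialized kernels adopt the log2f optimization, that is, they evaluate the softmax exponentials with exp2 instead of exp. A tiny standalone check of the identity being relied on, independent of the kernel code:

```cpp
#include <cassert>
#include <cmath>

int main()
{
    float const kLog2e = 1.4426950408889634074f; // log2(e), the same constant as in the diff
    float const x = -3.7f;                       // stand-in for a BMM1 output value
    // exp(x) == 2^(x * log2(e)), so folding log2(e) into scale_bmm1 lets the
    // kernel use exp2f for the softmax exponentials.
    assert(std::fabs(std::exp(x) - std::exp2(x * kLog2e)) < 1e-6f);
    return 0;
}
```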
// Support paged_kv_cache and chunked_attention.
@@ -194,34 +225,13 @@ public:
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
{
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
// Note that we apply scales and bias in the order of
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
const float scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
const float scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;

Data_type scale_type = launch_params.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
set_alpha(pagedKVParams.scale_bmm1, scale_bmm1, scale_type);
set_alpha(pagedKVParams.scale_softmax, scale_softmax, scale_type);
set_alpha(pagedKVParams.scale_bmm2, scale_bmm2, scale_type);

pagedKVParams.b = b;
pagedKVParams.h = mNumHeads;
pagedKVParams.s = s_q;
pagedKVParams.d = mHeadSize;

// Total sequence length needed by TMA descriptor
// it should be actual total seq length if non-padded input is given.
mTotalSeqLen = total_seqlen;
// Determine launch parameters.
mLaunchParams.set_default_kernel_selection_params();

TLLM_CHECK_WITH_INFO(tokens_per_kv_block >= 128, "FMHA with paged kv cache needs tokens_per_block >= 128 !");
// Needed by TMA descriptors.
launch_params.blocks_per_context_sequence = blocks_per_context_sequence;
pagedKVParams.q_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
pagedKVParams.kv_stride_in_bytes = tokens_per_kv_block * mHeadSize * sizeof(half);
pagedKVParams.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
mLaunchParams.blocks_per_context_sequence = blocks_per_context_sequence;

// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
const bool isSm90 = (sm == kSM_90);
@@ -229,53 +239,57 @@ public:
const bool isSm80 = (sm == kSM_80);

// always use flash attention kernels.
launch_params.flash_attention = true;
mLaunchParams.flash_attention = true;
// flash attention kernles s = 0 (support any seq length)
launch_params.kernel_s = 0;
launch_params.kernel_kv_s = s_kv;
launch_params.force_unroll = true;
mLaunchParams.kernel_s = 0;
mLaunchParams.kernel_kv_s = s_kv;
mLaunchParams.force_unroll = true;

// enable warp-specialization kernels when s > 512.
if (isSm90 && s_kv > 512)
{
launch_params.warp_specialization = true;
launch_params.use_tma = true;
mLaunchParams.warp_specialization = true;
mLaunchParams.use_tma = true;
}
else
{
// enable tiled kernels on Ampere/Ada
if (launch_params.flash_attention && s_kv <= 64)
if (mLaunchParams.flash_attention && s_kv <= 64)
{
// flash attention tiled kernels allows larger free dim tile size (M, N) with flexibility
// in unroll dimension tile size (K). for short sequence length (s<=128), tiled kernels
// can suffer from tile quantization loss therefore use flash attention non-tiled instead
launch_params.granular_tiling = false;
mLaunchParams.granular_tiling = false;
}
else if (isSm8x && params.d < 256)
else if (isSm8x && mParams.d < 256)
{
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
launch_params.granular_tiling = false;
mLaunchParams.granular_tiling = false;
}
else if (isSm90 || isSm80 || isSm8x)
{
// otherwise, choose tiled kernel for Ampere/Ada
launch_params.granular_tiling = true;
mLaunchParams.granular_tiling = true;
}
}

// alibi.
if (has_alibi)
// Use specialized ws kernels on Hopper for cases without alibi.
if (mLaunchParams.warp_specialization && !has_alibi)
{
pagedKVParams.has_alibi = true;
pagedKVParams.alibi_params = AlibiParams(mNumHeads, s_kv, tp_size, tp_rank, scale_after_alibi);
// Use specialized ws kernels for cases without alibi.
mLaunchParams.useKernelWithoutAlibi = true;
}

// Sliding_window_causal mask.
if (s_kv > sliding_window_size && launch_params.attention_mask_type == ContextAttentionMaskType::CAUSAL)
if (s_kv > sliding_window_size && mLaunchParams.attention_mask_type == ContextAttentionMaskType::CAUSAL)
{
pagedKVParams.sliding_window_size = sliding_window_size;
launch_params.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
mLaunchParams.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
}

setup_params(
mPagedKVParams, b, s_q, s_kv, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
mPagedKVParams.q_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
mPagedKVParams.kv_stride_in_bytes = tokens_per_kv_block * mHeadSize * sizeof(half);
}

// NOTE: assume that heads_interleaved = false (b, s, 3, h, d), and sequences are padded/non-padded
@@ -283,7 +297,7 @@ public:
void set_tma_descriptors()
{
// split D into multiple groups in order to match the TMA swizzle mode (128B)
const uint32_t d_in_bytes = params.d * sizeof(uint16_t);
const uint32_t d_in_bytes = mParams.d * sizeof(uint16_t);
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;

// separate q, k, and v tma descriptors
@@ -291,18 +305,18 @@ public:

// tensor size
uint32_t tensor_size_qkv[4];
if (params.h_kv < params.h)
if (mParams.h_kv < mParams.h)
{
// if multi-query or grouped-query
tensor_size_qkv[2] = 1;
tensor_size_qkv[1] = (params.h + 2 * params.h_kv);
tensor_size_qkv[0] = params.d; // params.d;
tensor_size_qkv[1] = (mParams.h + 2 * mParams.h_kv);
tensor_size_qkv[0] = mParams.d; // mParams.d;
}
else
{
tensor_size_qkv[2] = 3;
tensor_size_qkv[1] = params.h;
tensor_size_qkv[0] = params.d; // params.d;
tensor_size_qkv[1] = mParams.h;
tensor_size_qkv[0] = mParams.d; // mParams.d;
}

// box size for k and v
@@ -310,7 +324,7 @@ public:
// Update this on device?
box_size[2] = 1;
box_size[1] = 1;
box_size[0] = params.d / d_groups;
box_size[0] = mParams.d / d_groups;

// stride size in bytes. Assumes least significant dim is 1 (?)
uint64_t tensor_stride_qkv[3];
@@ -328,7 +342,7 @@ public:
uint32_t fp32_to_tf32 = 0;

// gmma descriptor mode
const uint32_t d_bytes_per_group = (params.d * sizeof(uint16_t)) / d_groups;
const uint32_t d_bytes_per_group = (mParams.d * sizeof(uint16_t)) / d_groups;
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
? cudaTmaDescSwizzle::SWIZZLE_128B
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
@@ -336,7 +350,7 @@ public:
uint32_t q_step = 0, kv_step = 0;
for (unsigned int i = 0u; i < sizeof(sTmaMetaInfo) / sizeof(sTmaMetaInfo[0]); ++i)
{
if (sTmaMetaInfo[i].mD == params.d)
if (sTmaMetaInfo[i].mD == mParams.d)
{
q_step = sTmaMetaInfo[i].mQStep;
kv_step = sTmaMetaInfo[i].mKVStep;
@@ -346,7 +360,7 @@ public:

// QKV [TOTAL, 3, h, d]
// NOTE: we may need to use actual seqlen to set oob_value
const char* qkv_ptr = reinterpret_cast<const char*>(params.qkv_ptr);
const char* qkv_ptr = reinterpret_cast<const char*>(mParams.qkv_ptr);
tensor_size_qkv[3] = mTotalSeqLen;

// Q: STEP_Q
@@ -354,18 +368,18 @@ public:
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&params.tma_desc_q);
&mParams.tma_desc_q);

// K/V: STEP_KV
box_size[3] = kv_step;
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&params.tma_desc_k);
&mParams.tma_desc_k);
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&params.tma_desc_v);
&mParams.tma_desc_v);
}

// Q are contiguous in the shape of [B, S, H, D]
@ -375,13 +389,13 @@ public:
|
||||
void set_paged_kv_tma_descriptors(cudaStream_t stream)
|
||||
{
|
||||
// split D into multiple groups in order to match the TMA swizzle mode (128B)
|
||||
const uint32_t d_in_bytes = pagedKVParams.d * sizeof(uint16_t);
|
||||
const uint32_t d_in_bytes = mPagedKVParams.d * sizeof(uint16_t);
|
||||
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
|
||||
|
||||
uint32_t q_step = 0, kv_step = 0;
|
||||
for (unsigned int i = 0u; i < sizeof(sTmaPagedKVMetaInfo) / sizeof(sTmaPagedKVMetaInfo[0]); ++i)
|
||||
{
|
||||
if (sTmaPagedKVMetaInfo[i].mD == pagedKVParams.d)
|
||||
if (sTmaPagedKVMetaInfo[i].mD == mPagedKVParams.d)
|
||||
{
|
||||
q_step = sTmaPagedKVMetaInfo[i].mQStep;
|
||||
kv_step = sTmaPagedKVMetaInfo[i].mKVStep;
|
||||
@ -392,21 +406,21 @@ public:
|
||||
// Separate q, and paged kv tma descriptors.
|
||||
Multiple_tma_descriptor<4> q_tma_descriptor;
|
||||
Multiple_tma_descriptor<4> paged_kv_tma_descriptor(
|
||||
pagedKVParams.b * 2 * launch_params.blocks_per_context_sequence);
|
||||
mPagedKVParams.b * 2 * mLaunchParams.blocks_per_context_sequence);
|
||||
// Contiguous Q
|
||||
// query tensor size [B x S, 1, H, D]
|
||||
uint32_t tensor_size_q[4];
|
||||
tensor_size_q[3] = mTotalSeqLen;
|
||||
tensor_size_q[2] = 1;
|
||||
tensor_size_q[1] = pagedKVParams.h;
|
||||
tensor_size_q[0] = pagedKVParams.d;
|
||||
tensor_size_q[1] = mPagedKVParams.h;
|
||||
tensor_size_q[0] = mPagedKVParams.d;
|
||||
|
||||
// box size for k and v
|
||||
uint32_t box_size_q[4];
|
||||
box_size_q[3] = q_step;
|
||||
box_size_q[2] = 1;
|
||||
box_size_q[1] = 1;
|
||||
box_size_q[0] = pagedKVParams.d / d_groups;
|
||||
box_size_q[0] = mPagedKVParams.d / d_groups;
|
||||
|
||||
// stride size in bytes.
|
||||
uint64_t tensor_stride_q[3];
|
||||
@ -424,34 +438,34 @@ public:
|
||||
uint32_t fp32_to_tf32 = 0;
|
||||
|
||||
// gmma descriptor mode
|
||||
const uint32_t d_bytes_per_group = (pagedKVParams.d * sizeof(uint16_t)) / d_groups;
|
||||
const uint32_t d_bytes_per_group = (mPagedKVParams.d * sizeof(uint16_t)) / d_groups;
|
||||
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
|
||||
? cudaTmaDescSwizzle::SWIZZLE_128B
|
||||
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
|
||||
|
||||
// Q ptr.
|
||||
const char* q_ptr = reinterpret_cast<const char*>(pagedKVParams.q_ptr);
|
||||
const char* q_ptr = reinterpret_cast<const char*>(mPagedKVParams.q_ptr);
|
||||
|
||||
// Q: STEP_Q.
|
||||
q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN,
|
||||
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
|
||||
tensor_size_q, tensor_stride_q, traversal_stride, box_size_q, oob_fill, fp32_to_tf32,
|
||||
&pagedKVParams.tma_desc_q);
|
||||
&mPagedKVParams.tma_desc_q);
|
||||
|
||||
// Paged KV
|
||||
// Per batch tensor size.
|
||||
uint32_t tensor_size_kv[4];
|
||||
tensor_size_kv[3] = 1;
|
||||
tensor_size_kv[2] = pagedKVParams.h_kv;
|
||||
tensor_size_kv[1] = pagedKVParams.paged_kv_cache.mTokensPerBlock;
|
||||
tensor_size_kv[0] = pagedKVParams.d;
|
||||
tensor_size_kv[2] = mPagedKVParams.h_kv;
|
||||
tensor_size_kv[1] = mPagedKVParams.paged_kv_cache.mTokensPerBlock;
|
||||
tensor_size_kv[0] = mPagedKVParams.d;
|
||||
|
||||
// Box size for k and v.
|
||||
uint32_t box_size_kv[4];
|
||||
box_size_kv[3] = 1;
|
||||
box_size_kv[2] = 1;
|
||||
box_size_kv[1] = kv_step;
|
||||
box_size_kv[0] = pagedKVParams.d / d_groups;
|
||||
box_size_kv[0] = mPagedKVParams.d / d_groups;
|
||||
|
||||
// Stride size in bytes.
|
||||
uint64_t tensor_stride_kv[3];
|
||||
@ -461,40 +475,40 @@ public:
|
||||
|
||||
// 2 stands for k, and v blocks.
|
||||
// We only need to prepare as many tma descriptos as the number of paged kv blocks for context.
|
||||
for (int block_idx = 0; block_idx < pagedKVParams.b * 2 * launch_params.blocks_per_context_sequence;
|
||||
for (int block_idx = 0; block_idx < mPagedKVParams.b * 2 * mLaunchParams.blocks_per_context_sequence;
|
||||
block_idx++)
|
||||
{
|
||||
int block_ptr_idx = int(block_idx / launch_params.blocks_per_context_sequence)
|
||||
* pagedKVParams.paged_kv_cache.mMaxBlocksPerSeq
|
||||
+ (block_idx % launch_params.blocks_per_context_sequence);
|
||||
int block_ptr_idx = int(block_idx / mLaunchParams.blocks_per_context_sequence)
|
||||
* mPagedKVParams.paged_kv_cache.mMaxBlocksPerSeq
|
||||
+ (block_idx % mLaunchParams.blocks_per_context_sequence);
|
||||
paged_kv_tma_descriptor.set_tma_desctriptor(
|
||||
reinterpret_cast<char*>(launch_params.paged_kv_block_ptrs[block_ptr_idx]), cudaTmaDescFormat::F16_RN,
|
||||
reinterpret_cast<char*>(mLaunchParams.paged_kv_block_ptrs[block_ptr_idx]), cudaTmaDescFormat::F16_RN,
|
||||
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
|
||||
tensor_size_kv, tensor_stride_kv, traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, block_idx);
|
||||
}
|
||||
|
||||
// set mMaxBlocksPerSeq to the number of blocks needed for context.
|
||||
pagedKVParams.paged_kv_cache.mMaxBlocksPerSeq = launch_params.blocks_per_context_sequence;
|
||||
mPagedKVParams.paged_kv_cache.mMaxBlocksPerSeq = mLaunchParams.blocks_per_context_sequence;
|
||||
|
||||
paged_kv_tma_descriptor.copy_to_device(pagedKVParams.tma_desc_paged_kv, stream);
|
||||
paged_kv_tma_descriptor.copy_to_device(mPagedKVParams.tma_desc_paged_kv, stream);
|
||||
}
|
||||
|
||||
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads)
|
||||
{
|
||||
// BF16 FMHA only accumulates on FP32
|
||||
launch_params.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
|
||||
launch_params.attention_mask_type
|
||||
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
|
||||
mLaunchParams.attention_mask_type
|
||||
= causal_mask ? ContextAttentionMaskType::CAUSAL : ContextAttentionMaskType::PADDING;
|
||||
|
||||
// Paged KV Cache.
|
||||
pagedKVParams.h_kv = num_kv_heads;
|
||||
mPagedKVParams.h_kv = num_kv_heads;
|
||||
TLLM_CHECK_WITH_INFO(mNumHeads % num_kv_heads == 0, "number of Query heads should be multiple of KV heads !");
|
||||
pagedKVParams.h_q_per_kv = mNumHeads / num_kv_heads;
|
||||
pagedKVParams.is_s_padded = is_s_padded;
|
||||
mPagedKVParams.h_q_per_kv = mNumHeads / num_kv_heads;
|
||||
mPagedKVParams.is_s_padded = is_s_padded;
|
||||
|
||||
// Contiguous Cache.
|
||||
params.h_kv = num_kv_heads;
|
||||
params.is_s_padded = is_s_padded;
|
||||
mParams.h_kv = num_kv_heads;
|
||||
mParams.is_s_padded = is_s_padded;
|
||||
}
|
||||
|
||||
bool fmha_supported()
|
||||
@ -504,39 +518,39 @@ public:
|
||||
|
||||
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
||||
{
|
||||
params.qkv_ptr = qkvPtr;
|
||||
params.o_ptr = outputPtr;
|
||||
params.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr);
|
||||
mParams.qkv_ptr = qkvPtr;
|
||||
mParams.o_ptr = outputPtr;
|
||||
mParams.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr);
|
||||
|
||||
if (sm == kSM_90 && launch_params.use_tma)
|
||||
if (sm == kSM_90 && mLaunchParams.use_tma)
|
||||
{
|
||||
// memcpy H2D has been removed by applying grid_constant tma descriptors.
|
||||
set_tma_descriptors();
|
||||
}
|
||||
|
||||
xmmaKernel->run(params, launch_params, stream);
|
||||
xmmaKernel->run(mParams, mLaunchParams, stream);
|
||||
}
|
||||
|
||||
void run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
|
||||
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
pagedKVParams.q_ptr = qPtr;
|
||||
pagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
|
||||
pagedKVParams.paged_kv_cache = pagedKVCache;
|
||||
pagedKVParams.o_ptr = outputPtr;
|
||||
pagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
|
||||
pagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
|
||||
mPagedKVParams.q_ptr = qPtr;
|
||||
mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
|
||||
mPagedKVParams.paged_kv_cache = pagedKVCache;
|
||||
mPagedKVParams.o_ptr = outputPtr;
|
||||
mPagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
|
||||
mPagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
|
||||
// paged kv block device ptrs on host (used by tma descriptors).
|
||||
launch_params.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost);
|
||||
mLaunchParams.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost);
|
||||
|
||||
if (sm == kSM_90 && launch_params.use_tma)
|
||||
if (sm == kSM_90 && mLaunchParams.use_tma)
|
||||
{
|
||||
// memcpy H2D is needed as we use multiple tma descriptors in device memory.
|
||||
set_paged_kv_tma_descriptors(stream);
|
||||
}
|
||||
|
||||
pagedKVXmmaKernel->run(pagedKVParams, launch_params, stream);
|
||||
pagedKVXmmaKernel->run(mPagedKVParams, mLaunchParams, stream);
|
||||
}
|
||||
|
||||
bool isValid(int s) const
|
||||
@ -578,9 +592,9 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
Fused_multihead_attention_params_v2 params;
|
||||
Fused_multihead_attention_paged_kv_params_v2 pagedKVParams;
|
||||
Launch_params launch_params;
|
||||
Fused_multihead_attention_params_v2 mParams;
|
||||
Fused_multihead_attention_paged_kv_params_v2 mPagedKVParams;
|
||||
Launch_params mLaunchParams;
|
||||
int sm;
|
||||
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
|
||||
const FusedMultiHeadAttentionPagedKVXMMAKernelV2* pagedKVXmmaKernel;
|
||||
|
||||
@ -301,9 +301,26 @@ struct Launch_params
|
||||
bool granular_tiling = false;
|
||||
// mask type: padding, causal, sliding_window_causal
|
||||
ContextAttentionMaskType attention_mask_type = ContextAttentionMaskType::PADDING;
|
||||
// use specialized kernels without alibi support.
|
||||
bool useKernelWithoutAlibi = false;
|
||||
// hardware properties to determine how to launch blocks
|
||||
int multi_processor_count = 0;
|
||||
int device_l2_cache_size = 0;
|
||||
|
||||
void set_default_kernel_selection_params()
|
||||
{
|
||||
kernel_s = 0;
|
||||
kernel_kv_s = 0;
|
||||
force_unroll = false;
|
||||
use_tma = false;
|
||||
flash_attention = false;
|
||||
warp_specialization = false;
|
||||
granular_tiling = false;
|
||||
attention_mask_type = (attention_mask_type == ContextAttentionMaskType::PADDING)
|
||||
? ContextAttentionMaskType::PADDING
|
||||
: ContextAttentionMaskType::CAUSAL;
|
||||
useKernelWithoutAlibi = false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -232,20 +232,21 @@ public:
}

inline uint64_t hashID(unsigned int s, unsigned int d, bool interleaved, bool unroll, bool force_fp32_acc,
bool flash_attention, int attention_mask_type, bool tiled) const
bool flash_attention, bool is_alibi_supported, int attention_mask_type, bool tiled) const
{
s = flash_attention ? 0 : s;
// D <= 2048
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 5) | (tiled ? 16ull : 0ull)
| (force_fp32_acc ? 8ull : 0ull) | (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull)
| (unroll ? 1ull : 0ull);
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 6) | (is_alibi_supported ? 32ull : 0ull)
| (tiled ? 16ull : 0ull) | (force_fp32_acc ? 8ull : 0ull) | (flash_attention ? 4ull : 0ull)
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}

virtual uint64_t hashID(const KernelMeta& kernelMeta) const
{

return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mAlibiSupported,
kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
}
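The widened key above adds an `is_alibi_supported` bit and shifts the mask type up to bit 6. A minimal standalone sketch of that bit packing follows; the `KernelKey` struct and field names are illustrative stand-ins (not the library's `KernelMeta`), only the bit layout mirrors the new `hashID()`.

```cpp
// Standalone illustration: packing kernel-selection fields into one 64-bit lookup key.
#include <cassert>
#include <cstdint>

struct KernelKey
{
    unsigned s;  // max sequence length bucket (0 for flash attention)
    unsigned d;  // head dimension
    bool unroll, interleaved, flashAttention, fp32Acc, tiled, alibiSupported;
    int maskType; // padding / causal mask variant
};

// Flags occupy the low bits, the mask type sits above them, and d / s live in the
// high bits, so distinct (s, d) pairs cannot collide with flag changes.
uint64_t packKey(KernelKey const& k)
{
    return (uint64_t) k.s << 32 | (uint64_t) k.d << 16 | (uint64_t) (k.maskType << 6)
        | (k.alibiSupported ? 32ull : 0ull) | (k.tiled ? 16ull : 0ull) | (k.fp32Acc ? 8ull : 0ull)
        | (k.flashAttention ? 4ull : 0ull) | (k.interleaved ? 2ull : 0ull) | (k.unroll ? 1ull : 0ull);
}

int main()
{
    KernelKey a{0, 128, true, false, true, true, false, true, 1};
    KernelKey b = a;
    b.maskType = 0;
    assert(packKey(a) != packKey(b)); // changing any field changes the key
    return 0;
}
```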
virtual void run(
|
||||
@ -288,10 +289,17 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
const auto findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
|
||||
forceUnroll, launch_params.force_fp32_acc, launch_params.flash_attention,
|
||||
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
|
||||
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "FMHA kernels are not found");
|
||||
const auto findIter
|
||||
= mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, forceUnroll,
|
||||
launch_params.force_fp32_acc, launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
|
||||
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
|
||||
|
||||
// Add debug info when kernels are not found.
|
||||
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(),
|
||||
"FMHA kernels are not found (kernel meta info: %d %d %d %d %d %d %d %d %d) !", launch_params.kernel_s,
|
||||
params.d, launch_params.interleaved, forceUnroll, launch_params.force_fp32_acc,
|
||||
launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
|
||||
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling);
|
||||
|
||||
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
|
||||
const CUfunction func = findIter->second.mDeviceFunction;
|
||||
@ -383,31 +391,39 @@ public:
|
||||
}
|
||||
|
||||
inline uint64_t hashID(unsigned int s, unsigned int d, bool interleaved, bool unroll, bool force_fp32_acc,
|
||||
bool flash_attention, bool warp_specialization, int attention_mask_type, bool tiled) const
|
||||
bool flash_attention, bool warp_specialization, bool is_alibi_supported, int attention_mask_type,
|
||||
bool tiled) const
|
||||
{
|
||||
s = flash_attention ? 0 : s;
|
||||
// D <= 2048
|
||||
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 6) | (warp_specialization ? 16ull : 0ull)
|
||||
| (tiled ? 16ull : 0ull) | (force_fp32_acc ? 8ull : 0ull) | (flash_attention ? 4ull : 0ull)
|
||||
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
|
||||
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 7) | (is_alibi_supported ? 64ull : 0ull)
|
||||
| (warp_specialization ? 32ull : 0ull) | (tiled ? 16ull : 0ull) | (force_fp32_acc ? 8ull : 0ull)
|
||||
| (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
|
||||
}
|
||||
|
||||
virtual uint64_t hashID(const KernelMeta& kernelMeta) const
|
||||
{
|
||||
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
|
||||
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization,
|
||||
kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
|
||||
kernelMeta.mAlibiSupported, kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
|
||||
}
|
||||
|
||||
virtual void run(
|
||||
Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const
|
||||
{
|
||||
|
||||
const auto findIter = mFunctions.find(
|
||||
hashID(launch_params.kernel_s, params.d, launch_params.interleaved, launch_params.force_unroll,
|
||||
launch_params.force_fp32_acc, launch_params.flash_attention, launch_params.warp_specialization,
|
||||
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
|
||||
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "FMHA kernels are not found");
|
||||
const auto findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
|
||||
launch_params.force_unroll, launch_params.force_fp32_acc, launch_params.flash_attention,
|
||||
launch_params.warp_specialization, !launch_params.useKernelWithoutAlibi,
|
||||
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
|
||||
|
||||
// Add debug info when kernels are not found.
|
||||
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(),
|
||||
"Paged KV FMHA kernels are not found (kernel meta info: %d %d %d %d %d %d %d %d %d %d) !",
|
||||
launch_params.kernel_s, params.d, launch_params.interleaved, launch_params.force_unroll,
|
||||
launch_params.force_fp32_acc, launch_params.flash_attention, launch_params.warp_specialization,
|
||||
!launch_params.useKernelWithoutAlibi, static_cast<int>(launch_params.attention_mask_type),
|
||||
launch_params.granular_tiling);
|
||||
|
||||
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
|
||||
const CUfunction func = findIter->second.mDeviceFunction;
|
||||
|
||||
File diffs suppressed because they are too large (25 files).
@ -142,9 +142,17 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
{
case CutlassGemmType::Simt: return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
case CutlassGemmType::WeightOnly:
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
if (sm >= 75)
{
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
}
else
{
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64};
}
case CutlassGemmType::Int8:
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
@ -156,8 +164,8 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
}
}

std::vector<CutlassGemmConfig> get_candidate_configs(
int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
const bool int8_configs_only, const int max_split_k)
{
std::vector<CutlassTileConfig> tiles
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);
@ -165,13 +173,20 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = int8_configs_only ? 3 : 2;
const int max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);

for (const auto& tile_config : tiles)
{
for (int stages = min_stages; stages <= max_stages; ++stages)
{
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
candidate_configs.push_back(config);
if (sm >= 75)
{
for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor)
{
auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages};
candidate_configs.push_back(config);
}
}
}
}

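With the new `max_split_k` argument, the candidate list becomes the cross product of tile shapes, stage counts, and serial split-K factors on SM75 and newer. A rough standalone sketch of that enumeration, using simplified stand-in types rather than the CUTLASS extension structs:

```cpp
// Standalone illustration of enumerating tile x stages x split-K GEMM candidates.
#include <cstdio>
#include <vector>

struct GemmConfig
{
    int tileId;       // which CTA/warp tile shape
    int splitKFactor; // 1 == no split-K
    int stages;       // software pipeline depth
};

std::vector<GemmConfig> enumerateConfigs(int numTiles, int minStages, int maxStages, int maxSplitK, bool allowSplitK)
{
    std::vector<GemmConfig> out;
    for (int tile = 0; tile < numTiles; ++tile)
    {
        for (int stages = minStages; stages <= maxStages; ++stages)
        {
            out.push_back({tile, 1, stages}); // the plain, non-split-K candidate
            if (allowSplitK)
            {
                for (int splitK = 2; splitK <= maxSplitK; ++splitK)
                {
                    out.push_back({tile, splitK, stages});
                }
            }
        }
    }
    return out;
}

int main()
{
    // e.g. 3 tiles x 3 stage counts x (1 + 6 split-K factors) = 63 candidates with max_split_k = 7
    auto const configs = enumerateConfigs(3, 2, 4, 7, true);
    std::printf("candidate configs: %zu\n", configs.size());
    return 0;
}
```

With the `SPLIT_K_LIMIT` of 7 used by the runners further down, each (tile, stages) pair now contributes seven candidates instead of one.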
@ -26,8 +26,9 @@ namespace kernels
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(
|
||||
int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false);
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
|
||||
const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false,
|
||||
const int max_split_k = 1);
|
||||
|
||||
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
|
||||
const std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig>& candidate_configs,
|
||||
|
||||
@ -75,7 +75,7 @@ public:
|
||||
protected:
|
||||
static constexpr int SPLIT_K_LIMIT = 7;
|
||||
static constexpr int MIN_M_TILE = 32;
|
||||
static constexpr int MIN_N_TILE = 128;
|
||||
static constexpr int MIN_N_TILE = 64;
|
||||
};
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
|
||||
@ -466,7 +466,8 @@ template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
std::vector<tkc::CutlassGemmConfig> CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::getConfigs() const
|
||||
{
|
||||
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
|
||||
std::vector<tkc::CutlassGemmConfig> candidateConfigs = get_candidate_configs(sm_, is_weight_only, false);
|
||||
std::vector<tkc::CutlassGemmConfig> candidateConfigs
|
||||
= get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT);
|
||||
return candidateConfigs;
|
||||
}
|
||||
|
||||
|
||||
@ -368,7 +368,7 @@ std::vector<tkc::CutlassGemmConfig> CutlassInt8GemmRunner<T>::getConfigs() const
|
||||
static constexpr bool isWeightOnly = false;
|
||||
std::vector<tkc::CutlassGemmConfig> candidateConfigs
|
||||
= get_candidate_configs(mSm, isWeightOnly, mSm <= 70, /* SIMT configs */
|
||||
true); /* INT8 configs */
|
||||
true, SPLIT_K_LIMIT); /* INT8 configs */
|
||||
return candidateConfigs;
|
||||
}
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ namespace tensorrt_llm
|
||||
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape, int Stages>
|
||||
void genericMoeGemmKernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream,
|
||||
int* kernel_occupancy = nullptr)
|
||||
{
|
||||
@ -152,11 +152,11 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
|
||||
struct dispatch_stages
|
||||
{
|
||||
static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_THROW("Cutlass fpA_intB gemm. Not instantiates for arch %d with stages set to %d",
|
||||
TLLM_THROW("Cutlass fpA_intB gemm. Not instantiated for arch %d with stages set to %d",
|
||||
arch::kMinComputeCapability, Stages);
|
||||
}
|
||||
};
|
||||
@ -166,12 +166,12 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
|
||||
struct dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>
|
||||
{
|
||||
static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
genericMoeGemmKernelLauncher<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
|
||||
weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, stream, occupancy);
|
||||
}
|
||||
};
|
||||
@ -182,20 +182,20 @@ struct dispatch_stages<T, WeightType, cutlass::arch::Sm80, EpilogueTag, Threadbl
|
||||
typename std::enable_if<(Stages > 2)>::type>
|
||||
{
|
||||
static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
genericMoeGemmKernelLauncher<T, WeightType, cutlass::arch::Sm80, EpilogueTag, ThreadblockShape, WarpShape,
|
||||
Stages>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, stream, occupancy);
|
||||
Stages>(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count, stream, occupancy);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape>
|
||||
void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
@ -203,17 +203,17 @@ void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales,
|
||||
{
|
||||
case 2:
|
||||
using DispatcherStages2 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>;
|
||||
DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k,
|
||||
DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case 3:
|
||||
using DispatcherStages3 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>;
|
||||
DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k,
|
||||
DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case 4:
|
||||
using DispatcherStages4 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>;
|
||||
DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k,
|
||||
DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break;
|
||||
@ -233,18 +233,18 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
|
||||
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
|
||||
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
|
||||
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
@ -268,18 +268,18 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
|
||||
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
|
||||
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
|
||||
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
@ -301,8 +301,8 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 8>,
|
||||
cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n,
|
||||
gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
@ -356,6 +356,13 @@ void MoeGemmRunner<T, WeightType>::dispatchToArch<EpilogueTag>(const T* A, const
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
|
||||
stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 90)
|
||||
{
|
||||
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
|
||||
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
|
||||
stream, occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Arch unsupported for MoE GEMM");
|
||||
|
||||
@ -17,10 +17,12 @@

#include "decoderMaskedMultiheadAttentionTemplate.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"

#include <algorithm>
#include <cuda_runtime_api.h>
#ifdef ENABLE_FP8
#include <cuda_fp8.h>
@ -91,24 +93,35 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_AT

template <typename T, int Dh, bool DO_CROSS_ATTENTION>
inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
int blocks_per_sm, int block_size, int tlength, bool do_multi_block)
int blocks_per_sm, int block_size, int tlength)
{
if (!do_multi_block)
if (!params.multi_block_mode)
{
params.multi_block_mode = false;
return;
}

params.seq_len_tile
int balanced_seq_len_tile
= mmha::divUp(params.multi_processor_count * blocks_per_sm, params.batch_size * params.num_heads);
params.seq_len_tile = std::max(params.min_seq_len_tile, params.seq_len_tile);

const int threads_per_value = mmha::threads_per_value<T>(mmha::dh_max(Dh));
// Make sure that each block at least processes one loop of kv (unroll size is default at 8).
const int seq_len_per_kv_loop = mmha::divUp(block_size, threads_per_value) * 8;
const int max_seq_len_tile = std::min(mmha::divUp(tlength + 1, seq_len_per_kv_loop), params.max_seq_len_tile);
int max_seq_len_tile = params.max_seq_len_tile;

params.seq_len_tile = std::min(params.seq_len_tile, max_seq_len_tile);
const bool multi_block_debug_flag = getEnvMmhaMultiblockDebug();

// User defined number of blocks.
if (multi_block_debug_flag)
{
const int env_seq_len_tile = getEnvMmhaBlocksPerSequence();
balanced_seq_len_tile = env_seq_len_tile > 0 ? env_seq_len_tile : balanced_seq_len_tile;
}
else
{
max_seq_len_tile = std::min(mmha::divUp(tlength + 1, seq_len_per_kv_loop), max_seq_len_tile);
}

params.seq_len_tile = std::clamp(balanced_seq_len_tile, params.min_seq_len_tile, max_seq_len_tile);

TLLM_CHECK_WITH_INFO(
params.seq_len_tile <= block_size, "The number of blocks per sequence may not exceed the thread block size.");
@ -118,6 +131,14 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<

params.multi_block_mode = (params.seq_len_tile > 1);

static bool debug_flag_printed_once = false;
if (multi_block_debug_flag && !debug_flag_printed_once)
{
TLLM_LOG_INFO("MMHA kernel info: threads per block(%d), launched_blocks_per_sequence(%d), sequence_length(%d).",
block_size, params.seq_len_tile, tlength + 1);
debug_flag_printed_once = true;
}

grid.z = params.seq_len_tile;
}
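The rewritten `multi_block_grid_setup` first computes a balanced tile count that fills the SMs, optionally takes an override from the debug environment variables, and then clamps the result between the caller's minimum and the number of KV loops in the sequence. A plain-C++ sketch of just that arithmetic, with ints standing in for the params struct (an assumed simplification, not the kernel code):

```cpp
// Standalone sketch of the multi-block tile balancing shown above.
#include <algorithm>
#include <cstdio>

int divUp(int a, int b) { return (a + b - 1) / b; }

// Pick how many CTAs split one sequence: enough to fill the GPU, but no more tiles
// than there are KV loops (so each CTA gets at least one loop of work), and never
// below the caller's minimum.
int pickSeqLenTile(int smCount, int blocksPerSm, int batchSize, int numHeads, int blockSize, int threadsPerValue,
    int timestepLength, int minTile, int maxTileLimit)
{
    int const balanced = divUp(smCount * blocksPerSm, batchSize * numHeads);
    int const seqLenPerKvLoop = divUp(blockSize, threadsPerValue) * 8;
    int const maxTile = std::min(divUp(timestepLength + 1, seqLenPerKvLoop), maxTileLimit);
    return std::clamp(balanced, minTile, std::max(minTile, maxTile));
}

int main()
{
    // A long sequence with a small batch wants many tiles; a large batch wants few.
    std::printf("tiles (batch 1):  %d\n", pickSeqLenTile(108, 2, 1, 32, 256, 4, 8192, 1, 32));
    std::printf("tiles (batch 64): %d\n", pickSeqLenTile(108, 2, 64, 32, 256, 4, 8192, 1, 32));
    return 0;
}
```

Because `multi_block_mode` is re-derived as `seq_len_tile > 1`, large batches that already fill the GPU fall back to single-block mode automatically.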
@ -230,7 +251,7 @@ void mmha_launch_kernel_ex(
|
||||
}
|
||||
|
||||
// If blocks with larger block size already fill all SMs, then disable the multi blocks mode.
|
||||
mmha::multi_block_grid_setup<T, Dh>(grid, params, available_blocks, dynamic_block_size, tlength, DO_MULTI_BLOCK);
|
||||
mmha::multi_block_grid_setup<T, Dh>(grid, params, available_blocks, dynamic_block_size, tlength);
|
||||
|
||||
// Launch kernels based on the valid block size.
|
||||
switch (dynamic_block_size)
|
||||
|
||||
@ -2549,7 +2549,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
|
||||
float final_max = -FLT_MAX;
|
||||
float thread_partial_max = -FLT_MAX;
|
||||
thread_partial_max = params.partial_max[bhi_seq_len_tile + min(tidx, gridDim.x - 1)];
|
||||
thread_partial_max = params.partial_max[bhi_seq_len_tile + min(tidx, gridDim.z - 1)];
|
||||
|
||||
// Make sure we can start writing to shared memory.
|
||||
__syncthreads();
|
||||
|
||||
@ -33,14 +33,15 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2>
|
||||
void run_cutlass_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2, typename cutlassType>
|
||||
void groupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
|
||||
std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize, void* cublasWorkSpace,
|
||||
int64_t cublasWorkspaceSize, cudaStream_t stream)
|
||||
int64_t cublasWorkspaceSize, nvinfer1::DataType dataType, cudaStream_t stream)
|
||||
{
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
using ElementA = cutlassType;
|
||||
using ElementB = cutlassType;
|
||||
using ElementOutput = cutlassType;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
@ -71,9 +72,10 @@ void run_cutlass_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
|
||||
float beta = 0.0f;
|
||||
typename Gemm::EpilogueOutputOp::Params epilogue_op(alpha, beta);
|
||||
|
||||
auto gemm_coord_size = tensorrt_llm::common::divUp(problem_count * sizeof(cutlass::gemm::GemmCoord), 16) * 16;
|
||||
auto ptr_size = tensorrt_llm::common::divUp(problem_count * sizeof(half*), 16) * 16;
|
||||
auto ldd_size = tensorrt_llm::common::divUp(problem_count * sizeof(int64_t), 16) * 16;
|
||||
auto gemm_coord_size
|
||||
= tensorrt_llm::common::divUp((size_t) problem_count * sizeof(cutlass::gemm::GemmCoord), (size_t) 16) * 16;
|
||||
auto ptr_size = tensorrt_llm::common::divUp((size_t) problem_count * sizeof(half*), (size_t) 16) * 16;
|
||||
auto ldd_size = tensorrt_llm::common::divUp((size_t) problem_count * sizeof(int64_t), (size_t) 16) * 16;
|
||||
|
||||
char* host_workspace = (char*) std::malloc(workSpaceSize);
|
||||
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_workspace);
|
||||
@ -135,24 +137,46 @@ void run_cutlass_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");

std::free(host_workspace);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void run_cutlass_1(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
template <int M1, int N1, int K1, int M2, int N2, int K2>
void groupedGemmType_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream)
void* cublasWorkSpace, int64_t cublasWorkspaceSize, nvinfer1::DataType dataType, cudaStream_t stream)
{
// For lora in, which has smaller N
run_cutlass_<128, 32, 32, 32, 32, 32>(
problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, stream);
if (dataType == nvinfer1::DataType::kHALF)
{
groupedGemm_<M1, N1, K1, M2, N2, K2, cutlass::half_t>(problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace,
workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, dataType, stream);
}
else if (dataType == nvinfer1::DataType::kFLOAT)
{
TLLM_CHECK_WITH_INFO(false, "not support float input/output");
}
#ifdef ENABLE_BF16
else if (dataType == nvinfer1::DataType::kBF16)
{
groupedGemm_<M1, N1, K1, M2, N2, K2, cutlass::bfloat16_t>(problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace,
workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, dataType, stream);
}
#endif
}

void run_cutlass_2(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream)
void gropuedGemm(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize, void* cublasWorkSpace,
int64_t cublasWorkspaceSize, bool isLoraIn, nvinfer1::DataType dataType, cudaStream_t stream)
{
// For lora out, which has larger N
run_cutlass_<128, 128, 32, 64, 64, 32>(
problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, stream);
if (isLoraIn)
{
groupedGemmType_<16, 32, 64, 16, 32, 64>(problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize,
cublasWorkSpace, cublasWorkspaceSize, dataType, stream);
}
else
{
groupedGemmType_<32, 128, 32, 32, 32, 32>(problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize,
cublasWorkSpace, cublasWorkspaceSize, dataType, stream);
}
}

} // namespace kernels

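`groupedGemmType_` turns the runtime `nvinfer1::DataType` tag into a compile-time CUTLASS element type before calling the templated kernel. A toy version of that dispatch shape follows; `DType`, `doGroupedGemm`, and the stand-in element types are invented for illustration and are not the plugin's real API.

```cpp
// Standalone illustration: runtime dtype tag -> compile-time element type dispatch.
#include <cstdio>
#include <stdexcept>

enum class DType { Half, BFloat16, Float };

template <typename Element, int TileM, int TileN, int TileK>
void doGroupedGemm(int problemCount)
{
    std::printf("grouped GEMM: %d problems, elem=%zu bytes, tile=%dx%dx%d\n", problemCount, sizeof(Element), TileM,
        TileN, TileK);
}

template <int TileM, int TileN, int TileK>
void dispatchByDtype(DType dtype, int problemCount)
{
    switch (dtype)
    {
    case DType::Half: doGroupedGemm<short, TileM, TileN, TileK>(problemCount); break;               // fp16 stand-in
    case DType::BFloat16: doGroupedGemm<unsigned short, TileM, TileN, TileK>(problemCount); break;  // bf16 stand-in
    default: throw std::runtime_error("float input/output is not supported");
    }
}

int main()
{
    dispatchByDtype<16, 32, 64>(DType::Half, 4);      // "LoRA in" style: small N tile
    dispatchByDtype<32, 128, 32>(DType::BFloat16, 4); // "LoRA out" style: larger N tile
    return 0;
}
```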
@ -16,19 +16,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
void run_cutlass_1(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
|
||||
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
|
||||
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream);
|
||||
|
||||
void run_cutlass_2(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
|
||||
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
|
||||
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream);
|
||||
void gropuedGemm(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
|
||||
std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize, void* cublasWorkSpace,
|
||||
int64_t cublasWorkspaceSize, bool isLoraIn, nvinfer1::DataType dataType, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@
* limitations under the License.
*/

#include "tensorrt_llm/common/workspace.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
@ -58,8 +59,9 @@ static constexpr int WARP_SIZE = 32;
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moeSoftmax(const T* input, const bool* finished, T* output, const int num_cols)
template <int TPB>
__launch_bounds__(TPB) __global__
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@ -81,7 +83,7 @@ __launch_bounds__(TPB) __global__ void moeSoftmax(const T* input, const bool* fi
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
threadData = max(input[idx], threadData);
}

const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
@ -111,16 +113,16 @@ __launch_bounds__(TPB) __global__ void moeSoftmax(const T* input, const bool* fi
{
const int idx = thread_row_offset + ii;
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = T(val);
output[idx] = val;
}
}

template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const T* inputs_after_softmax, const bool* finished, T* output,
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{

using cub_kvp = cub::KeyValuePair<int, T>;
using cub_kvp = cub::KeyValuePair<int, float>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;

@ -135,7 +137,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(const T* inputs_after_softmax, co
for (int k_idx = 0; k_idx < k; ++k_idx)
{
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
thread_kvp.value = -1.f; // This is OK because inputs are probabilities

cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
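The gating kernels above now keep the routing math in float regardless of the model dtype: a row-wise softmax over the expert logits, then an iterative top-k that re-selects the maximum k times and masks each winner out with a negative sentinel (safe because the inputs are probabilities). Below is a sequential CPU reference of that intent, a simplified sketch only: it omits the finished-row masking and the expert-parallel start/end offsets handled by the real kernels.

```cpp
// CPU reference for softmax-then-top-k expert routing, kept in float as the
// updated kernels are. Mirrors the intent of moeSoftmax + moeTopK, not their
// block-parallel implementation details.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void routeRow(const std::vector<float>& logits, int k, std::vector<float>& probs, std::vector<int>& experts)
{
    // Numerically stable softmax over the expert logits.
    float const maxLogit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> soft(logits.size());
    float sum = 0.f;
    for (size_t i = 0; i < logits.size(); ++i)
    {
        soft[i] = std::exp(logits[i] - maxLogit);
        sum += soft[i];
    }
    for (float& v : soft)
    {
        v /= sum;
    }

    // Iterative top-k: pick the max, record it, then exclude it and repeat.
    probs.assign(k, 0.f);
    experts.assign(k, -1);
    for (int kIdx = 0; kIdx < k; ++kIdx)
    {
        int best = 0;
        for (size_t e = 1; e < soft.size(); ++e)
        {
            if (soft[e] > soft[best])
            {
                best = static_cast<int>(e);
            }
        }
        probs[kIdx] = soft[best];
        experts[kIdx] = best;
        soft[best] = -1.f; // safe sentinel: probabilities are in [0, 1]
    }
}

int main()
{
    std::vector<float> probs;
    std::vector<int> experts;
    routeRow({0.1f, 2.0f, -1.0f, 1.5f}, 2, probs, experts);
    for (int i = 0; i < 2; ++i)
    {
        std::printf("expert %d with weight %.3f\n", experts[i], probs[i]);
    }
    return 0;
}
```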
@ -189,9 +191,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const T* inputs_after_softmax, co
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
void topkGatingSoftmax(const T* input, const bool* finished, T* output, const int num_rows, int* indices,
|
||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
@ -201,7 +203,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
// Number of bytes each thread pulls in per load
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
||||
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
||||
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
||||
@ -243,20 +245,20 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
|
||||
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
|
||||
// row it will read.
|
||||
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
|
||||
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
|
||||
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
|
||||
// this can support all powers of 2 up to 16.
|
||||
using AccessType = cutlass::AlignedArray<T, ELTS_PER_LDG>;
|
||||
using AccessType = cutlass::AlignedArray<float, ELTS_PER_LDG>;
|
||||
|
||||
// Finally, we pull in the data from global mem
|
||||
cutlass::Array<T, VPT> row_chunk_input;
|
||||
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_input);
|
||||
cutlass::Array<float, VPT> row_chunk;
|
||||
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
|
||||
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
|
||||
@ -264,14 +266,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
}
|
||||
|
||||
using ComputeType = float;
|
||||
using Converter = cutlass::NumericArrayConverter<ComputeType, T, VPT>;
|
||||
Converter compute_type_converter;
|
||||
cutlass::Array<ComputeType, VPT> row_chunk = compute_type_converter(row_chunk_input);
|
||||
|
||||
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
|
||||
// convert to float afterwards for the exp + sum reduction.
|
||||
ComputeType thread_max = row_chunk[0];
|
||||
float thread_max = row_chunk[0];
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < VPT; ++ii)
|
||||
{
|
||||
@ -370,7 +367,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
|
||||
// single) thread per row of the input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = T(max_val);
|
||||
output[idx] = max_val;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
}
|
||||
@ -386,7 +383,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
{
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f);
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -395,10 +392,10 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
namespace detail
|
||||
{
|
||||
// Constructs some constants needed to partition the work across threads at compile time.
|
||||
template <typename T, int EXPERTS, int BYTES_PER_LDG>
|
||||
template <int EXPERTS, int BYTES_PER_LDG>
|
||||
struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
@ -407,28 +404,27 @@ struct TopkConstants
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, int EXPERTS, int WARPS_PER_TB>
|
||||
void topkGatingSoftmaxLauncherHelper(const T* input, const bool* finished, T* output, int* indices, int* source_row,
|
||||
const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
template <int EXPERTS, int WARPS_PER_TB>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
|
||||
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||
}

template <typename T>
void topkGatingSoftmaxKernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output,
int* indices, int* source_row, const int num_rows, const int num_experts, const int k, const int start_expert,
const int end_expert, cudaStream_t stream)
void topkGatingSoftmaxKernelLauncher(const float* input, const bool* finished, float* output,
float* softmax_temp_output, int* indices, int* source_row, const int num_rows, const int num_experts, const int k,
const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr int WARPS_PER_TB = 4;

@ -436,55 +432,55 @@ void topkGatingSoftmaxKernelLauncher(const T* input, const bool* finished, T* ou
{
case 1:
{
topkGatingSoftmaxLauncherHelper<T, 1, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<1, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 2:
{
topkGatingSoftmaxLauncherHelper<T, 2, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<2, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 4:
{
topkGatingSoftmaxLauncherHelper<T, 4, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<4, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 8:
{
topkGatingSoftmaxLauncherHelper<T, 8, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<8, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 16:
{
topkGatingSoftmaxLauncherHelper<T, 16, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<16, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 32:
{
topkGatingSoftmaxLauncherHelper<T, 32, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<32, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 64:
{
topkGatingSoftmaxLauncherHelper<T, 64, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<64, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 128:
{
topkGatingSoftmaxLauncherHelper<T, 128, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<128, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
case 256:
{
topkGatingSoftmaxLauncherHelper<T, 256, WARPS_PER_TB>(
topkGatingSoftmaxLauncherHelper<256, WARPS_PER_TB>(
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream);
break;
}
@ -492,8 +488,8 @@ void topkGatingSoftmaxKernelLauncher(const T* input, const bool* finished, T* ou
{
static constexpr int TPB = 256;
TLLM_CHECK(softmax_temp_output != nullptr);
moeSoftmax<T, TPB><<<num_rows, TPB, 0, stream>>>(input, finished, softmax_temp_output, num_experts);
moeTopK<T, TPB><<<num_rows, TPB, 0, stream>>>(
moeSoftmax<TPB><<<num_rows, TPB, 0, stream>>>(input, finished, softmax_temp_output, num_experts);
moeTopK<TPB><<<num_rows, TPB, 0, stream>>>(
softmax_temp_output, finished, output, indices, source_row, num_experts, k, start_expert, end_expert);
}
}
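The switch above only covers power-of-two expert counts up to 256; any other count falls through to the generic `moeSoftmax` + `moeTopK` path. A small sketch of that predicate (the helper name is made up for illustration; the same power-of-two test reappears below when deciding whether the separate softmax buffer is needed):

```cpp
#include <cstdio>

// Sketch: expert counts that are powers of two and <= 256 hit the fused
// topkGatingSoftmax specializations; all other counts take the two-kernel
// moeSoftmax + moeTopK fallback.
bool usesFusedTopkGatingSoftmax(int num_experts)
{
    const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    return is_pow_2 && num_experts <= 256;
}

int main()
{
    std::printf("%d %d %d\n",
        usesFusedTopkGatingSoftmax(64),   // 1: fused path
        usesFusedTopkGatingSoftmax(96),   // 0: not a power of two
        usesFusedTopkGatingSoftmax(512)); // 0: larger than 256
    return 0;
}
```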
@ -649,7 +645,7 @@ enum class ScaleMode : int
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection.
template <typename T, int RESIDUAL_NUM, bool HAS_BIAS, ScaleMode SCALE_MODE, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1,
const T* skip_2, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row,
const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row, const int cols, const int k, const int64_t* num_valid_ptr)
{
const int original_row = blockIdx.x;
@ -672,14 +668,14 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red
for (int tid = threadIdx.x; tid < cols; tid += blockDim.x)
{
T thread_output{0.f};
T row_rescale{0.f};
float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx)
{
const int expanded_original_row = original_row + k_idx * num_rows;
const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row];

const int64_t k_offset = original_row * k + k_idx;
const T row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? T(1.f) : scales[k_offset];
const float row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
if constexpr (SCALE_MODE == ScaleMode::RENORM_SCALE)
{
row_rescale = row_rescale + row_scale;
@ -698,13 +694,14 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red
const T* bias_ptr = bias + expert_idx * cols;
const T bias_value = HAS_BIAS ? bias_ptr[tid] : T(0.f);

thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_value);
thread_output = static_cast<float>(thread_output)
+ row_scale * static_cast<float>(expanded_permuted_rows_row_ptr[tid] + bias_value);
}

if (SCALE_MODE == ScaleMode::RENORM_SCALE && (!CHECK_SKIPPED || thread_output))
{
assert(row_rescale != T(0.f));
thread_output = thread_output / row_rescale;
assert(row_rescale != 0.f);
thread_output = static_cast<float>(thread_output) / row_rescale;
}

if (RESIDUAL_NUM == 1)
@ -721,7 +718,7 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red

template <typename T, int RESIDUAL_NUM>
void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows, T* reduced_unpermuted_output,
const T* skip_1, const T* skip_2, const T* bias, const T* scales,
const T* skip_1, const T* skip_2, const T* bias, const float* scales,
const int* expanded_source_row_to_expanded_dest_row, const int* expert_for_source_row, const int num_rows,
const int cols, const int k, const int64_t* num_valid_ptr, MOEParallelismConfig parallelism_config,
MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
@ -766,7 +763,7 @@ void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows,

template <typename T>
void finalizeMoeRoutingKernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1,
const T* skip_2, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row,
const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row, const int num_rows, const int cols, const int k, const int64_t* num_valid_ptr,
MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
{
@ -834,6 +831,47 @@ void doGatedActivation(T* output, const T* gemm_result, const int64_t* num_valid
fn<<<blocks, threads, 0, stream>>>(output, gemm_result, num_valid_tokens_ptr, inter_size);
}
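As a plain-C++ reference for what `finalizeMoeRoutingKernel` computes per output element after this change: the k expert copies of a token are accumulated with their float routing scales and, in RENORM_SCALE mode, divided by the sum of those scales. This is a simplified sketch that ignores bias, skip connections, and the permutation indirection:

```cpp
#include <cassert>
#include <vector>

// Sketch: k-way weighted reduction over the k expert outputs of one (row, col),
// with optional renormalization by the sum of the routing scales (RENORM_SCALE).
float reduceOneElement(const std::vector<float>& expert_outputs, // k values for this (row, col)
    const std::vector<float>& scales,                            // k routing scales for this row
    bool renormalize)
{
    float acc = 0.f;
    float scale_sum = 0.f;
    for (size_t k_idx = 0; k_idx < expert_outputs.size(); ++k_idx)
    {
        acc += scales[k_idx] * expert_outputs[k_idx];
        scale_sum += scales[k_idx];
    }
    if (renormalize)
    {
        assert(scale_sum != 0.f);
        acc /= scale_sum;
    }
    return acc;
}

int main()
{
    const float out = reduceOneElement({1.0f, 3.0f}, {0.25f, 0.75f}, /*renormalize=*/true);
    return out > 0.f ? 0 : 1;
}
```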

template <typename T, typename WeightType, typename Enable>
std::vector<size_t> CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceBufferSizes(const int num_rows,
const int hidden_size, const int inter_size, const int num_experts, const int num_experts_per_node, const int k,
ActivationType activation_type) const
{
const size_t num_moe_inputs = k * num_rows;
const size_t buf_size = num_moe_inputs * hidden_size;
const size_t interbuf_elems = num_moe_inputs * inter_size;
const size_t glu_inter_elems = isGatedActivation(activation_type) ? (interbuf_elems * 2) : 0;
int num_softmax_outs = 0;

const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256)
{
num_softmax_outs = num_rows * num_experts;
}

size_t source_rows_size = num_moe_inputs * sizeof(int);
size_t permuted_rows_size = num_moe_inputs * sizeof(int);
size_t permuted_experts_size = num_moe_inputs * sizeof(int);
size_t permuted_data_size = buf_size * sizeof(T);
size_t total_rows_before_expert_size = num_experts_per_node * sizeof(int64_t);
size_t softmax_out_size = num_softmax_outs * sizeof(float);
size_t glu_inter_size = glu_inter_elems * sizeof(T);
size_t fc1_result_size = interbuf_elems * sizeof(T);
size_t sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts);

std::vector<size_t> workspace{
source_rows_size,
permuted_rows_size,
permuted_experts_size,
permuted_data_size,
total_rows_before_expert_size,
softmax_out_size,
glu_inter_size,
// These pointers reuse the same memory
std::max(fc1_result_size, sorter_size),
};
return workspace;
}
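The per-slot sizes returned above are later summed into one allocation and sliced back into pointers via `calculateTotalWorkspaceSize` and `nextWorkspacePtr` (see `configureWsPtrs` below). A hedged sketch of that pattern, assuming simple 16-byte padding per slot; the helper names here are stand-ins, and the real helpers in `tensorrt_llm::common` may differ in detail:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the "sum the slot sizes, then slice one big buffer" pattern.
// alignTo16 is an assumption standing in for the library's padding helpers.
static size_t alignTo16(size_t bytes)
{
    return (bytes + 15) / 16 * 16;
}

size_t totalWorkspaceBytes(const std::vector<size_t>& slot_sizes)
{
    size_t total = 0;
    for (size_t s : slot_sizes)
    {
        total += alignTo16(s);
    }
    return total;
}

// Given the start of the allocation, return one pointer per slot, in order.
std::vector<int8_t*> sliceWorkspace(int8_t* base, const std::vector<size_t>& slot_sizes)
{
    std::vector<int8_t*> slots;
    for (size_t s : slot_sizes)
    {
        slots.push_back(base);
        base += alignTo16(s);
    }
    return slots;
}

int main()
{
    std::vector<size_t> sizes{100, 40, 256};
    std::vector<int8_t> buffer(totalWorkspaceBytes(sizes));
    auto slots = sliceWorkspace(buffer.data(), sizes);
    return slots.size() == sizes.size() ? 0 : 1;
}
```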

template <typename T, typename WeightType, typename Enable>
size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(const int num_rows, const int hidden_size,
const int inter_size, const int num_experts, const int k, ActivationType activation_type,
@ -841,38 +879,9 @@ size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(const int num
{
const int ep_size = parallelism_config.ep_size;
TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of tp size");
const int buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size);
const int glu_inter_size
= isGatedActivation(activation_type) ? pad_to_multiple_of_16(k * num_rows * inter_size * 2) : 0;
const int interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size) + glu_inter_size;
const int padded_experts = pad_to_multiple_of_16(num_experts / ep_size);
const int num_moe_inputs = pad_to_multiple_of_16(k * num_rows);
int num_softmax_outs = 0;

const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256)
{
num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts);
}

// softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them
// in Encoder or Decoder before invoking FfnLayer forward.
size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
total_ws_bytes += buf_size * sizeof(T); // permuted_data
total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
total_ws_bytes += num_softmax_outs * sizeof(T);
const int bytes_for_fc1_result = interbuf_size * sizeof(T);
const int sorter_ws_size_bytes = pad_to_multiple_of_16(CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts));

int bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result)
{
int remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result);
bytes_for_intermediate_and_sorting += remaining_bytes;
}

total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace
return total_ws_bytes;
auto workspace = getWorkspaceBufferSizes(
num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type);
return tensorrt_llm::common::calculateTotalWorkspaceSize(workspace.data(), workspace.size());
}

template <typename T, typename WeightType, typename Enable>
@ -880,45 +889,43 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::configureWsPtrs(char* ws_ptr, co
const int inter_size, const int num_experts, const int num_experts_per_node, const int k,
ActivationType activation_type)
{
const int buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size);
const int glu_inter_size
= isGatedActivation(activation_type) ? pad_to_multiple_of_16(k * num_rows * inter_size * 2) : 0;
const int interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size) + glu_inter_size;
const int sorter_ws_size_bytes = pad_to_multiple_of_16(CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts));
const size_t interbuf_sorter_size_bytes = std::max(interbuf_size * sizeof(T), (size_t) sorter_ws_size_bytes);
const int padded_experts = pad_to_multiple_of_16(num_experts_per_node);
const int num_moe_inputs = pad_to_multiple_of_16(k * num_rows);
auto workspace = getWorkspaceBufferSizes(
num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, activation_type);

source_rows_ = (int*) ws_ptr;
permuted_rows_ = source_rows_ + num_moe_inputs;
permuted_experts_ = permuted_rows_ + num_moe_inputs;
permuted_data_ = (T*) (permuted_experts_ + num_moe_inputs);
std::vector<int8_t*> ws_sliced{(int8_t*) ws_ptr};
for (auto size : workspace)
{
ws_sliced.push_back(nextWorkspacePtr(ws_sliced.back(), size));
}

total_rows_before_expert_ = (int64_t*) (permuted_data_ + buf_size);
source_rows_ = (int*) ws_sliced[0];
permuted_rows_ = (int*) ws_sliced[1];
permuted_experts_ = (int*) ws_sliced[2];
permuted_data_ = (T*) ws_sliced[3];

// These pointers are aliased. Since the sort ws can be overwritten after it is finished
glu_inter_result_ = (T*) (total_rows_before_expert_ + padded_experts);
fc1_result_ = glu_inter_result_ + glu_inter_size;
sorter_ws_ = (char*) glu_inter_result_;
total_rows_before_expert_ = (int64_t*) ws_sliced[4];

softmax_out_ = nullptr;
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256)
{
softmax_out_ = (T*) (sorter_ws_ + interbuf_sorter_size_bytes);
}
else
{
softmax_out_ = nullptr;
softmax_out_ = (float*) ws_sliced[5];
}

glu_inter_result_ = (T*) ws_sliced[6];

// These pointers are aliased. Since the sort ws can be overwritten after it is finished
sorter_ws_ = (char*) ws_sliced[7];
fc1_result_ = (T*) ws_sliced[7];
}
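Note the aliasing at the end of `configureWsPtrs`: `sorter_ws_` and `fc1_result_` both point at slot 7, which is why `getWorkspaceBufferSizes` sizes that slot as `std::max(fc1_result_size, sorter_size)`. A trivial sketch of that sizing rule:

```cpp
#include <algorithm>
#include <cstddef>

// Sketch: two buffers whose lifetimes do not overlap (the CUB sort workspace and
// the FC1 result) can share one workspace slot sized for the larger of the two.
std::size_t sharedSlotBytes(std::size_t fc1_result_bytes, std::size_t cub_sorter_bytes)
{
    return std::max(fc1_result_bytes, cub_sorter_bytes);
}

int main()
{
    return sharedSlotBytes(1 << 20, 1 << 16) == (1u << 20) ? 0 : 1;
}
```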

template <typename T, typename WeightType, typename Enable>
void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activations_void,
const void* gating_output_void, const void* fc1_expert_weights_void, const void* fc1_scales_void,
const void* fc1_expert_biases_void, ActivationType fc1_activation_type, const void* fc2_expert_weights_void,
const void* fc2_scales_void, const void* fc2_expert_biases_void, const int num_rows, const int hidden_size,
const int inter_size, const int num_experts, const int k, char* workspace_ptr, void* final_output_void,
void* fc2_result_void, const bool* finished, const int active_rows, void* expert_scales_void,
void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activations_void, const float* gating_output,
const void* fc1_expert_weights_void, const void* fc1_scales_void, const void* fc1_expert_biases_void,
ActivationType fc1_activation_type, const void* fc2_expert_weights_void, const void* fc2_scales_void,
const void* fc2_expert_biases_void, const int num_rows, const int hidden_size, const int inter_size,
const int num_experts, const int k, char* workspace_ptr, void* final_output_void, void* fc2_result_void,
const bool* finished, const int active_rows, void* expert_scales_void,
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, MOEParallelismConfig parallelism_config,
MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
{
@ -926,7 +933,6 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activat
= std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value;

auto* input_activations = static_cast<const T*>(input_activations_void);
auto* gating_output = static_cast<const T*>(gating_output_void);
auto* fc1_expert_weights = static_cast<const WeightType*>(fc1_expert_weights_void);
auto* fc1_scales = static_cast<const T*>(fc1_scales_void);
auto* fc1_expert_biases = static_cast<const T*>(fc1_expert_biases_void);
@ -935,7 +941,7 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activat
auto* fc2_expert_biases = static_cast<const T*>(fc2_expert_biases_void);
auto* final_output = static_cast<T*>(final_output_void);
auto* fc2_result = static_cast<T*>(fc2_result_void);
auto* expert_scales = static_cast<T*>(expert_scales_void);
auto* expert_scales = static_cast<float*>(expert_scales_void);

TLLM_CHECK(input_activations);
TLLM_CHECK(gating_output);
@ -965,7 +971,7 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activat

configureWsPtrs(
workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, fc1_activation_type);
topkGatingSoftmaxKernelLauncher<T>(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
topkGatingSoftmaxKernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
source_rows_, num_rows, num_experts, k, start_expert, end_expert, stream);

sync_check_cuda_error();
@ -1008,8 +1014,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::runMoe(const void* input_activat

sync_check_cuda_error();

doGatedActivation<T>(
fc1_result_, glu_inter_result_, num_valid_tokens_ptr, inter_size, num_rows, fc1_activation_type, stream);
doGatedActivation<T>(fc1_result_, glu_inter_result_, num_valid_tokens_ptr, inter_size, num_rows * k,
fc1_activation_type, stream);
}

sync_check_cuda_error();

@ -137,7 +137,7 @@ public:
virtual void setTactic(std::optional<cutlass_extensions::CutlassGemmConfig> gemm_config) = 0;
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() = 0;

virtual void runMoe(const void* input_activations, const void* gating_output, const void* fc1_expert_weights,
virtual void runMoe(const void* input_activations, const float* gating_output, const void* fc1_expert_weights,
const void* fc1_scales, const void* fc1_expert_biases, ActivationType fc1_activation_type,
const void* fc2_expert_weights, const void* fc2_scales, const void* fc2_expert_biases, const int num_rows,
const int hidden_size, const int inter_size, const int num_experts, const int k, char* workspace_ptr,
@ -173,7 +173,7 @@ public:
return moe_gemm_runner_.getConfigs();
}

void runMoe(const void* input_activations, const void* gating_output, const void* fc1_expert_weights,
void runMoe(const void* input_activations, const float* gating_output, const void* fc1_expert_weights,
const void* fc1_scales, const void* fc1_expert_biases, ActivationType fc1_activation_type,
const void* fc2_expert_weights, const void* fc2_scales, const void* fc2_expert_biases, const int num_rows,
const int hidden_size, const int inter_size, const int num_experts, const int k, char* workspace_ptr,
@ -185,6 +185,8 @@ public:
private:
void computeTotalRowsBeforeExpert(const int* sorted_indices, const int total_indices, const int num_experts,
int64_t* total_rows_before_expert, cudaStream_t stream);
std::vector<size_t> getWorkspaceBufferSizes(const int num_rows, const int hidden_size, const int inter_size,
const int num_experts, const int num_experts_per_node, const int k, ActivationType activation_type) const;
void configureWsPtrs(char* ws_ptr, const int num_rows, const int hidden_size, const int inter_size,
const int num_experts, const int num_experts_per_node, const int k, ActivationType activation_type);

@ -198,7 +200,7 @@ private:
int* permuted_experts_;
char* sorter_ws_;
T* permuted_data_;
T* softmax_out_;
float* softmax_out_;

int64_t* total_rows_before_expert_;

@ -224,7 +226,7 @@ public:
return;
}

void runMoe(const void* input_activations, const void* gating_output, const void* fc1_expert_weights,
void runMoe(const void* input_activations, const float* gating_output, const void* fc1_expert_weights,
const void* fc1_scales, const void* fc1_expert_biases, ActivationType fc1_activation_type,
const void* fc2_expert_weights, const void* fc2_scales, const void* fc2_expert_biases, const int num_rows,
const int hidden_size, const int inter_size, const int num_experts, const int k, char* workspace_ptr,

@ -209,8 +209,11 @@ __global__ void batchApplyRepetitionPenalty(T* logits, const float* penalties, c
{
// outputIds shape: (batchSize, input_len + output_len)
int penaltyIndex = outputIds[batchIdx][blockIdx.y * maxSeqLen + index];
assert(penaltyIndex < vocabSize);
penaltyIndices[index] = penaltyIndex;
if (penaltyIndex >= vocabSize)
{
continue;
}
float logit = (float) logits[penaltyIndex];
if (penaltyType == RepetitionPenaltyType::Additive)
{
@ -239,6 +242,10 @@ __global__ void batchApplyRepetitionPenalty(T* logits, const float* penalties, c
// Phase 2. Replace a logit value by the penalized one.
for (int index = threadIdx.x; index < currentStep; index += blockDim.x)
{
if (penaltyIndices[index] >= vocabSize)
{
continue;
}
logits[penaltyIndices[index]] = penaltyLogits[index];
}
}
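The added checks make both phases of `batchApplyRepetitionPenalty` skip token ids at or beyond `vocabSize` instead of asserting. A single-threaded sketch of the same gather-penalize-scatter flow, showing only a multiplicative penalty and the new bounds check (a simplification for illustration, not the kernel itself):

```cpp
#include <vector>

// Sketch: gather the logits of previously generated tokens, apply a multiplicative
// repetition penalty, then scatter the penalized values back. Token ids outside
// [0, vocabSize) are skipped, mirroring the bounds checks added in the kernel.
void applyRepetitionPenalty(std::vector<float>& logits, const std::vector<int>& outputIds, float penalty)
{
    const int vocabSize = static_cast<int>(logits.size());
    std::vector<int> penaltyIndices;
    std::vector<float> penaltyLogits;

    // Phase 1: gather and penalize.
    for (int tokenId : outputIds)
    {
        if (tokenId < 0 || tokenId >= vocabSize)
        {
            continue;
        }
        const float logit = logits[tokenId];
        penaltyIndices.push_back(tokenId);
        penaltyLogits.push_back(logit > 0.f ? logit / penalty : logit * penalty);
    }

    // Phase 2: scatter the penalized values back.
    for (size_t i = 0; i < penaltyIndices.size(); ++i)
    {
        logits[penaltyIndices[i]] = penaltyLogits[i];
    }
}

int main()
{
    std::vector<float> logits{2.f, -1.f, 0.5f};
    applyRepetitionPenalty(logits, {0, 7}, 1.5f); // id 7 is out of range and is skipped
    return logits[0] < 2.f ? 0 : 1;
}
```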

@ -37,10 +37,14 @@ __global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_i
int time_step = threadIdx.x + blockIdx.x * blockDim.x;
int bb_id = threadIdx.y + blockIdx.y * blockDim.y; // should be just blockIdx.y?
const int current_step{sequence_lengths[bb_id] - 1}; // the sequence_lengths is updated, need to minus 1
const int input_length{input_lengths == nullptr ? 0 : input_lengths[bb_id]};
const int batch_id = bb_id / beam_width;
const int beam_id = bb_id % beam_width;
if (bb_id >= beam_width * local_batch_size || time_step < (max_seq_len - max_attention_window)
|| finished[bb_id].isFinished())
// Exit when the batch_beam or timestep is out of the bound.
// Assume that KV Cache is shared and fixed for context part,
// so we don't need to update the indices for context part.
if (bb_id >= beam_width * local_batch_size || time_step >= max_seq_len || time_step < input_length
|| time_step < (max_seq_len - max_attention_window) || finished[bb_id].isFinished())
{
return;
}
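The rewritten guard is essentially a richer early-exit predicate for the indirection-cache update. Restated as a small host-side helper for readability (a sketch; the finished-state check from the kernel is left out):

```cpp
// Sketch: which (batch_beam, time_step) pairs the indirection-cache update may skip.
// The context part of the KV cache is shared and fixed, so steps before the input
// length need no update; the kernel's finished-state check is omitted here.
bool skipIndirCacheUpdate(int bb_id, int time_step, int local_batch_size, int beam_width, int input_length,
    int max_seq_len, int max_attention_window)
{
    return bb_id >= beam_width * local_batch_size || time_step >= max_seq_len || time_step < input_length
        || time_step < (max_seq_len - max_attention_window);
}

int main()
{
    // E.g. a timestep inside the prompt is skipped.
    return skipIndirCacheUpdate(/*bb_id=*/0, /*time_step=*/3, /*local_batch_size=*/1, /*beam_width=*/2,
               /*input_length=*/8, /*max_seq_len=*/64, /*max_attention_window=*/64)
        ? 0
        : 1;
}
```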
@ -90,14 +94,14 @@ BaseBeamSearchLayer<T>::BaseBeamSearchLayer(BaseBeamSearchLayer<T> const& beam_s
template <typename T>
BaseBeamSearchLayer<T>::~BaseBeamSearchLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}

template <typename T>
void BaseBeamSearchLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&temperature_buf_));
@ -105,26 +109,26 @@ void BaseBeamSearchLayer<T>::freeBuffer()
allocator_->free((void**) (&repetition_penalty_buf_));
is_allocate_buffer_ = false;
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void BaseBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
temperature_buf_ = allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false);
min_lengths_buf_ = allocator_->reMalloc(min_lengths_buf_, sizeof(int) * batch_size, false);
repetition_penalty_buf_ = allocator_->reMalloc(repetition_penalty_buf_, sizeof(float) * batch_size, false);

is_allocate_buffer_ = true;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void BaseBeamSearchLayer<T>::setupBase(size_t batch_size, SetupParams const& setupParams)
{
allocateBuffer(batch_size);
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// Setup penalties.
FillBuffers const fillBuffers{batch_size, stream_};

@ -149,13 +153,13 @@ void BaseBeamSearchLayer<T>::setupBase(size_t batch_size, SetupParams const& set
fillBuffers(setupParams.presence_penalty, 1.0f, mRepetitionPenalty, repetition_penalty_buf_);
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardParams const& params)
{
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
Tensor& output_ids_ptr = outputs.output_ids_ptr;

const auto batch_size = static_cast<std::int32_t>(output_ids_ptr.shape[0]);

@ -33,7 +33,7 @@ namespace layers
template <typename T>
void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
curandstate_buf_ = allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false);
random_seeds_buf_ = allocator_->reMalloc(random_seeds_buf_, sizeof(unsigned long long) * batch_size, false);
temperature_buf_ = allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false);
@ -52,7 +52,7 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size)
template <typename T>
void BaseSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&curandstate_buf_));
@ -93,7 +93,7 @@ BaseSamplingLayer<T>::~BaseSamplingLayer()
template <typename T>
void BaseSamplingLayer<T>::setupBase(const size_t batch_size, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
allocateBuffer(batch_size);

// If runtime argument has single random seed, using this random seed to
@ -170,7 +170,7 @@ void BaseSamplingLayer<T>::setupBase(const size_t batch_size, SetupParams const&
template <typename T>
void BaseSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams const& params)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto const batch_size = outputs.output_ids_ptr.shape[0];
auto const local_batch_size = params.logits.shape[0];
@ -234,7 +234,7 @@ void BaseSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams
freeBuffer();
}
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class BaseSamplingLayer<float>;

@ -37,7 +37,7 @@ namespace layers
template <typename T>
void DynamicDecodeLayer<T>::initialize()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mOnlineBeamsearchDecode = std::make_unique<OnlineBeamSearchLayer<T>>(
vocab_size_, vocab_size_padded_, stream_, allocator_, is_free_buffer_after_forward_);

@ -58,14 +58,14 @@ DynamicDecodeLayer<T>::DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_p
, vocab_size_padded_(vocab_size_padded)
, cuda_device_prop_(cuda_device_prop)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
initialize();
}

template <typename T>
DynamicDecodeLayer<T>::~DynamicDecodeLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}

@ -76,7 +76,7 @@ DynamicDecodeLayer<T>::DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_deco
, vocab_size_padded_(dynamic_decode_layer.vocab_size_padded_)
, cuda_device_prop_(dynamic_decode_layer.cuda_device_prop_)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
initialize();
}

@ -117,7 +117,7 @@ bool hasDiffRuntimeArgs(DecodingSetupParams const& params)
template <typename T>
void DynamicDecodeLayer<T>::setup(size_t batch_size, size_t beam_width, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);

if (beam_width == 1)
{ // sampling layers
@ -159,7 +159,7 @@ void DynamicDecodeLayer<T>::setup(size_t batch_size, size_t beam_width, SetupPar
template <typename T>
void DynamicDecodeLayer<T>::allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mIdsPtrHost->resize(2 * batch_size);
zero_parent_ids = allocator_->reMalloc(zero_parent_ids, sizeof(int*) * 2 * batch_size, false);
}
@ -167,14 +167,14 @@ void DynamicDecodeLayer<T>::allocateBuffer(size_t batch_size, size_t beam_width,
template <typename T>
void DynamicDecodeLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
allocator_->free((void**) &zero_parent_ids);
}

template <typename T>
void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const& params)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);

const auto ite = params.ite;
const auto step = params.step;

@ -97,7 +97,7 @@ void invokeUpdate(FinishedState* finished, int** parent_ids_ptr, int* sequence_l
template <typename T>
void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
BaseBeamSearchLayer<T>::setupBase(batch_size, setupParams);
allocateBuffer(batch_size);

@ -108,13 +108,13 @@ void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setup

fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_);
fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params)
{
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
Tensor const& output_ids_ptr = outputs.output_ids_ptr;
const auto batch_size = static_cast<std::int32_t>(output_ids_ptr.shape[0]);
const auto beam_width = static_cast<std::int32_t>(output_ids_ptr.shape[1]);
@ -159,7 +159,7 @@ void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, So
template <typename T>
void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// we need to check 2 * beam_width candidates each time
// 64 is the max beam width we support now.
topk_softmax_workspace_size_ = (size_t) (ceil(batch_size * 64 * (64 * 2) / 4.) * 4 * 2
@ -171,13 +171,13 @@ void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
length_penalties_buf_ = allocator_->reMalloc(length_penalties_buf_, sizeof(float) * batch_size, false);

is_allocate_buffer_ = true;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void OnlineBeamSearchLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&topk_softmax_workspace_));
@ -185,7 +185,7 @@ void OnlineBeamSearchLayer<T>::freeBuffer()
allocator_->free((void**) (&length_penalties_buf_));
is_allocate_buffer_ = false;
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
@ -199,13 +199,13 @@ template <typename T>
OnlineBeamSearchLayer<T>::OnlineBeamSearchLayer(OnlineBeamSearchLayer<T> const& beam_search_layer)
: BaseBeamSearchLayer<T>(beam_search_layer)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
}

template <typename T>
OnlineBeamSearchLayer<T>::~OnlineBeamSearchLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
}

template class OnlineBeamSearchLayer<float>;

@ -85,7 +85,7 @@ __global__ void setup_topk_runtime_args(int batch_size, uint32_t top_k, uint32_t
template <typename T>
void TopKSamplingLayer<T>::allocateBuffer(size_t const batch_size, std::vector<uint32_t> const& top_k)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
uint32_t max_top_k = (top_k.size() > 0) ? *std::max_element(std::begin(top_k), std::end(top_k)) : 1;
if (max_top_k == 0)
{
@ -104,7 +104,7 @@ void TopKSamplingLayer<T>::allocateBuffer(size_t const batch_size, std::vector<u
template <typename T>
void TopKSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&sampling_workspace_));
@ -118,7 +118,7 @@ void TopKSamplingLayer<T>::freeBuffer()
template <typename T>
void TopKSamplingLayer<T>::setup(size_t const batch_size, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
BaseSamplingLayer<T>::setupBase(batch_size, setupParams);

uint32_t const default_top_k = 0;
@ -162,7 +162,7 @@ void TopKSamplingLayer<T>::setup(size_t const batch_size, SetupParams const& set
template <typename T>
void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingParams const& params)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto const batch_size = outputs.output_ids_ptr.shape[0];
auto const local_batch_size = params.logits.shape[0];
@ -223,7 +223,7 @@ TopKSamplingLayer<T>::TopKSamplingLayer(TopKSamplingLayer<T> const& top_k_sampli
template <typename T>
TopKSamplingLayer<T>::~TopKSamplingLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}


@ -105,7 +105,7 @@ static __global__ void set_topp_runtime_args(int batch_size, std::uint32_t top_k
template <typename T>
void TopPSamplingLayer<T>::allocateBuffer(std::size_t batch_size, std::vector<float> const& top_p)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
float const max_top_p = (top_p.size() > 0) ? *std::max_element(std::begin(top_p), std::end(top_p)) : 0.0f;
invokeTopPSampling<T>(nullptr, // workspace
sampling_workspace_size_, cub_temp_storage_size_,
@ -136,7 +136,7 @@ void TopPSamplingLayer<T>::allocateBuffer(std::size_t batch_size, std::vector<fl
template <typename T>
void TopPSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (is_allocate_buffer_)
{
allocator_->free((void**) (&sampling_workspace_));
@ -157,7 +157,7 @@ void TopPSamplingLayer<T>::freeBuffer()
template <typename T>
void TopPSamplingLayer<T>::setup(std::size_t const batch_size, SetupParams const& setupParams)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
BaseSamplingLayer<T>::setupBase(batch_size, setupParams);

std::uint32_t const default_top_k = 0;
@ -229,7 +229,7 @@ void TopPSamplingLayer<T>::setup(std::size_t const batch_size, SetupParams const
template <typename T>
void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingParams const& params)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);

auto const batch_size = outputs.output_ids_ptr.shape[0];
auto const local_batch_size = params.logits.shape[0];
@ -288,7 +288,7 @@ TopPSamplingLayer<T>::TopPSamplingLayer(TopPSamplingLayer<T> const& top_p_sampli
template <typename T>
TopPSamplingLayer<T>::~TopPSamplingLayer()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}