Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5302040][feat] Add whisper support (Bert Attention on SM100 and GPTAttention for cross attention on SM100) (#5527)
Signed-off-by: tinyinl <tinyinl@nvidia.com>
Parent: bda42f8c3a
Commit: 6c52bb07ff
@@ -520,7 +520,7 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
             cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream);
             cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t),
                 cudaMemcpyHostToDevice, stream);
-            mFMHARunner->run(fmhaParams);
+            mFmhaDispatcher->run(fmhaParams);
             if (iter != 0)
             {
                 invokeRecoverFromRA<T>((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_,
@@ -704,7 +704,18 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
         }

         // Run the fmha kernel.
-        mFMHARunner->run(fmhaParams);
+
+        // TODO: set it correctly for contiguous kv buffer (cross-attention).
+        fmhaParams.totalKvSeqLen = num_tokens;
+
+        fmhaParams.cuKvSeqLenPtr = cu_seqlens;
+        fmhaParams.cuMaskRowsPtr = cu_seqlens;
+        fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
+
+        fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
+        fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
+        fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
+        mFmhaDispatcher->run(fmhaParams);
         sync_check_cuda_error(stream);
         if (mSageAttn)
         {
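Note on the hunk above: for Bert (self-)attention the dispatcher reuses the Q cumulative-sequence-length buffer for the KV side, and the contiguous-KV cross-attention case is left as a TODO. A minimal Python sketch of that parameter wiring, with a hypothetical dataclass standing in for the C++ fmhaParams struct (not the real API):

from dataclasses import dataclass
from typing import Any

@dataclass
class FmhaParamsSketch:
    # Hypothetical stand-in for fmhaParams; field names mirror the diff above.
    total_kv_seq_len: int = 0
    cu_kv_seq_len_ptr: Any = None
    cu_mask_rows_ptr: Any = None
    tile_counter_ptr: Any = None
    scale_bmm1_ptr: Any = None
    scale_bmm2_ptr: Any = None
    force_fp32_acc: bool = False

def wire_self_attention_kv(params: FmhaParamsSketch, num_tokens: int, cu_seqlens: Any,
                           tile_counter: Any, scale_bmm1: Any, scale_bmm2: Any,
                           force_fp32_acc: bool) -> FmhaParamsSketch:
    # Self-attention: the KV sequence lengths equal the Q sequence lengths, so the
    # same cumulative-length buffer feeds both the KV-length and mask-row fields.
    params.total_kv_seq_len = num_tokens
    params.cu_kv_seq_len_ptr = cu_seqlens
    params.cu_mask_rows_ptr = cu_seqlens
    params.tile_counter_ptr = tile_counter
    params.scale_bmm1_ptr = scale_bmm1
    params.scale_bmm2_ptr = scale_bmm2
    params.force_fp32_acc = force_fp32_acc
    return params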
@@ -948,10 +959,14 @@ int BertAttentionPlugin::initialize() noexcept
         }

-        // Load kernels from the pre-compiled cubins.
-        mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams));
+        // The KV input data type. The default is same as dataType.
+        fmhaParams.dataTypeKv = data_type;
+        fmhaParams.headSizeV = mHeadSize;
+
+        // Load kernels from the pre-compiled cubins.
+        mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
         // Fall back to unfused MHA kernels if not supported.
-        mEnableContextFMHA = mFMHARunner->isFmhaSupported();
+        mEnableContextFMHA = mFmhaDispatcher->isSupported();
     }

 #if ENABLE_MULTI_DEVICE

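The initialize() hunk follows a simple pattern: finish filling the FMHA parameter struct (now including the KV data type and the V head size), build the dispatcher from it, and fall back to unfused MHA when no fused kernel is available. A hedged Python sketch of that flow, with hypothetical names standing in for the C++ FmhaDispatcher (not its real interface):

class FmhaDispatcherSketch:
    # Hypothetical, simplified dispatcher; the real one selects a fused-attention
    # backend based on data types, head sizes and the SM version.
    def __init__(self, params: dict):
        self.params = params

    def is_supported(self) -> bool:
        # Illustrative check only: pretend support depends on the V head size.
        return self.params.get("head_size_v") in (32, 64, 80, 128)

    def run(self, params: dict) -> None:
        assert self.is_supported(), "caller should have fallen back to unfused MHA"
        # ... launch the selected fused kernel ...

def init_context_fmha(fmha_params: dict, data_type: str, head_size: int):
    # Mirrors the diff: set KV dtype (default: same as dataType) and V head size,
    # construct the dispatcher, then gate the context-FMHA path on its support.
    fmha_params["data_type_kv"] = data_type
    fmha_params["head_size_v"] = head_size
    dispatcher = FmhaDispatcherSketch(fmha_params)
    enable_context_fmha = dispatcher.is_supported()
    return dispatcher, enable_context_fmha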
@@ -18,7 +18,7 @@

 #include "tensorrt_llm/common/cublasMMWrapper.h"
 #include "tensorrt_llm/common/quantization.h"
-#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h"
+#include "tensorrt_llm/kernels/fmhaDispatcher.h"
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/plugins/common/plugin.h"
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
@@ -114,7 +114,7 @@ private:
     cudaStream_t mNcclStream;

     // The default copy constructor will leave them as nullptr. clone() shall initialize it.
-    UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
+    UniqPtrWNullCopy<tensorrt_llm::kernels::FmhaDispatcher> mFmhaDispatcher;
     UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
 };

@@ -30,9 +30,9 @@ from . import graph_rewriting as gw
 from ._common import default_net, default_trtnet, precision
 from ._utils import (QuantModeWrapper, bf16_array, bool_array,
                      dim_resolve_negative, dim_to_trt_axes, dims_array,
-                     fp16_array, fp32_array, int32_array, int64_array,
-                     np_dtype_to_trt, str_dtype_to_trt, trt_dtype_to_np,
-                     trt_dtype_to_str)
+                     fp16_array, fp32_array, get_sm_version, int32_array,
+                     int64_array, np_dtype_to_trt, str_dtype_to_trt,
+                     trt_dtype_to_np, trt_dtype_to_str)
 from .network import PluginInfo, set_np_weight, set_plugin_info
 from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
 from .quantization import QuantMode
@@ -5719,7 +5719,8 @@ def gpt_attention(
     if (attention_mask is not None) or (attention_packed_mask is not None):
         # context fmha needs packed mask.
         assert attention_packed_mask is not None
-        mask_type = AttentionMaskType.custom_mask
+        if get_sm_version() < 100:
+            mask_type = AttentionMaskType.custom_mask

     mask_type_filed = trt.PluginField("mask_type",
                                       np.array([int(mask_type)], np.int32),
@@ -5844,7 +5845,7 @@ def gpt_attention(
     if attention_mask is not None and mask_type == AttentionMaskType.custom_mask:
         # useFullCustomMask
        plug_inputs += [attention_mask]
-    if attention_packed_mask is not None:
+    if attention_packed_mask is not None and get_sm_version() < 100:
         # usePackedCustomMask
         plug_inputs += [attention_packed_mask]
     if use_cache:
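Taken together, the two gpt_attention hunks gate the custom-mask handling on the SM version: below SM100 the plugin keeps AttentionMaskType.custom_mask and receives the packed mask as an extra plugin input, while on SM100 the packed-mask input is skipped. A condensed, hypothetical sketch of that gating, with a plain sm_version argument in place of get_sm_version() and the plugin-field plumbing:

from enum import IntEnum

class MaskType(IntEnum):
    # Illustrative values only; stands in for AttentionMaskType.
    padding = 0
    custom_mask = 1

def select_mask_inputs(mask_type, attention_mask, attention_packed_mask, sm_version):
    # Mirrors the diff: context FMHA on pre-SM100 parts needs the packed mask.
    if attention_mask is not None or attention_packed_mask is not None:
        assert attention_packed_mask is not None
        if sm_version < 100:
            mask_type = MaskType.custom_mask

    plug_inputs = []
    if attention_mask is not None and mask_type == MaskType.custom_mask:
        plug_inputs.append(attention_mask)           # full custom mask
    if attention_packed_mask is not None and sm_version < 100:
        plug_inputs.append(attention_packed_mask)    # packed custom mask
    return mask_type, plug_inputs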
@@ -20,8 +20,8 @@ import tensorrt as trt
 import torch

 from .._common import default_net, precision
-from .._utils import (fp32_array, get_sm_version, int32_array, is_same_dtype,
-                      set_obj_attrs, trt_dtype_to_np, trt_dtype_to_str)
+from .._utils import (fp32_array, int32_array, is_same_dtype, set_obj_attrs,
+                      trt_dtype_to_np, trt_dtype_to_str)

 # isort: off
 from ..functional import (
@@ -1755,8 +1755,6 @@ class BertAttention(Module):
         if default_net().plugin_config.bert_attention_plugin:
             # TRT plugin mode
             assert input_lengths is not None
-            assert get_sm_version() < 100 or get_sm_version() >= 120, \
-                "bert_attention_plugin does not support SM100"
             context = bert_attention(
                 qkv,
                 input_lengths,
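With the SM100 guard removed from BertAttention, the plugin branch no longer rejects Blackwell parts; whether it is taken depends only on the plugin config, and the C++ side's isSupported()/unfused-MHA fallback (see the initialize() hunk) handles unsupported configurations instead. A stripped-down sketch of the resulting control flow, with hypothetical run_plugin/run_ootb callables rather than the real module internals:

def bert_attention_forward(qkv, input_lengths, plugin_enabled, run_plugin, run_ootb):
    # Hypothetical sketch of the branch shown in the last hunk.
    if plugin_enabled:
        # TRT plugin mode: the SM100 assertion is gone; only input_lengths is required here.
        assert input_lengths is not None
        return run_plugin(qkv, input_lengths)
    # Otherwise use the non-plugin (OOTB) attention path.
    return run_ootb(qkv)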