Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5302040][feat] Add whisper support (Bert Attention on SM100 and GPTAttention for cross attention on SM100) (#5527)
Signed-off-by: tinyinl <tinyinl@nvidia.com>
Parent: bda42f8c3a
Commit: 6c52bb07ff
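In short: the BertAttention plugin now drives its context FMHA through FmhaDispatcher instead of FusedMHARunnerV2, and the Python frontend only takes the packed custom-mask path on pre-SM100 GPUs; together these enable Whisper's encoder self-attention (BertAttention) and decoder cross attention (GPTAttention) on SM100. The sketch below is a plain-PyTorch illustration of the cross-attention pattern this targets, with queries from decoder tokens and keys/values from the encoder output as in Whisper; it is not TensorRT-LLM code, and every name in it is made up for the example.

import torch

def cross_attention(decoder_hidden, encoder_output, num_heads):
    # decoder_hidden: [batch, q_len, hidden]; encoder_output: [batch, kv_len, hidden].
    # Q/K/V projections are omitted; only the attention pattern matters here.
    batch, q_len, hidden = decoder_hidden.shape
    kv_len = encoder_output.shape[1]
    head_dim = hidden // num_heads

    def heads(x, seq_len):
        return x.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

    q = heads(decoder_hidden, q_len)     # queries come from the decoder
    k = heads(encoder_output, kv_len)    # keys/values come from the encoder output
    v = heads(encoder_output, kv_len)

    scores = q @ k.transpose(-1, -2) / head_dim**0.5
    out = scores.softmax(dim=-1) @ v
    return out.transpose(1, 2).reshape(batch, q_len, hidden)

dec = torch.randn(2, 7, 64)    # 7 decoder tokens
enc = torch.randn(2, 15, 64)   # 15 encoder frames
print(cross_attention(dec, enc, num_heads=4).shape)  # torch.Size([2, 7, 64])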
@@ -520,7 +520,7 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
             cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream);
             cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t),
                 cudaMemcpyHostToDevice, stream);
-            mFMHARunner->run(fmhaParams);
+            mFmhaDispatcher->run(fmhaParams);
             if (iter != 0)
             {
                 invokeRecoverFromRA<T>((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_,
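The hunk above sits in the ring-attention path of enqueueImpl: each iteration runs FMHA on one KV block and, after the first iteration, invokeRecoverFromRA folds the partial result into the running output using the softmax statistics saved in ring_softmax_accu_stats_buf_. A minimal NumPy sketch of that merge rule follows (the standard flash/ring-attention rescale-and-accumulate; the actual kernel's layout and details may differ):

import numpy as np

def attend_block(q, k, v):
    # Unnormalized partial output plus per-row softmax statistics (max, sum).
    scores = q @ k.T / np.sqrt(q.shape[-1])
    row_max = scores.max(axis=-1, keepdims=True)
    p = np.exp(scores - row_max)
    return p @ v, row_max, p.sum(axis=-1, keepdims=True)

def merge(out_a, max_a, sum_a, out_b, max_b, sum_b):
    # Rescale both partial results to a common max, then accumulate.
    new_max = np.maximum(max_a, max_b)
    scale_a, scale_b = np.exp(max_a - new_max), np.exp(max_b - new_max)
    return out_a * scale_a + out_b * scale_b, new_max, sum_a * scale_a + sum_b * scale_b

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k, v = rng.normal(size=(16, 8)), rng.normal(size=(16, 8))

# One-shot attention over the full KV...
out_full, _, sum_full = attend_block(q, k, v)
# ...matches two KV blocks merged through their saved statistics.
out_ab, _, sum_ab = merge(*attend_block(q, k[:8], v[:8]), *attend_block(q, k[8:], v[8:]))
print(np.allclose(out_full / sum_full, out_ab / sum_ab))  # True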
@@ -704,7 +704,18 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
         }

         // Run the fmha kernel.
-        mFMHARunner->run(fmhaParams);
+        // TODO: set it correctly for contiguous kv buffer (cross-attention).
+        fmhaParams.totalKvSeqLen = num_tokens;
+
+        fmhaParams.cuKvSeqLenPtr = cu_seqlens;
+        fmhaParams.cuMaskRowsPtr = cu_seqlens;
+        fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
+
+        fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
+        fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
+        fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
+
+        mFmhaDispatcher->run(fmhaParams);
         sync_check_cuda_error(stream);
         if (mSageAttn)
         {
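The added assignments hand the dispatcher the buffers that packed (padding-free) attention needs: a cumulative sequence-length array and the total token count, with the KV side temporarily reusing the Q-side lengths (hence the TODO about a contiguous KV buffer for cross attention). A small sketch of how such a cu_seqlens buffer is typically built, with made-up variable names:

import numpy as np

seq_lens = np.array([3, 5, 2], dtype=np.int32)        # per-sequence lengths in the batch
cu_seqlens = np.zeros(len(seq_lens) + 1, dtype=np.int32)
cu_seqlens[1:] = np.cumsum(seq_lens)                   # exclusive prefix sum
num_tokens = int(cu_seqlens[-1])                       # total packed tokens

print(cu_seqlens)   # [ 0  3  8 10]
print(num_tokens)   # 10

# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed buffer,
# which is why self-attention can point the KV lengths at the same array.
packed = np.arange(num_tokens)
for i in range(len(seq_lens)):
    print(i, packed[cu_seqlens[i]:cu_seqlens[i + 1]])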
@@ -948,10 +959,14 @@ int BertAttentionPlugin::initialize() noexcept
         }

-        // Load kernels from the pre-compiled cubins.
-        mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams));
+        // The KV input data type. The default is same as dataType.
+        fmhaParams.dataTypeKv = data_type;
+        fmhaParams.headSizeV = mHeadSize;
+
+        // Load kernels from the pre-compiled cubins.
+        mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
         // Fall back to unfused MHA kernels if not supported.
-        mEnableContextFMHA = mFMHARunner->isFmhaSupported();
+        mEnableContextFMHA = mFmhaDispatcher->isSupported();
     }

 #if ENABLE_MULTI_DEVICE
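initialize() now constructs an FmhaDispatcher and asks it whether a fused context-FMHA kernel exists for the requested configuration, falling back to the unfused BMM-softmax-BMM path otherwise; presumably this indirection is what lets SM100 pick a supported kernel backend. A Python sketch of that dispatch-with-fallback shape, with all class names invented for the illustration:

class FusedBackend:
    def __init__(self, sm_version):
        self.sm_version = sm_version

    def is_supported(self):
        # Stand-in for a per-architecture capability check.
        return self.sm_version in (80, 86, 89, 90, 100, 120)

    def run(self, params):
        return f"fused attention on SM{self.sm_version}: {params}"

class UnfusedBackend:
    def run(self, params):
        return f"unfused BMM + softmax + BMM fallback: {params}"

class AttentionDispatcher:
    def __init__(self, sm_version):
        self._fused = FusedBackend(sm_version)
        self._fallback = UnfusedBackend()

    def run(self, params):
        backend = self._fused if self._fused.is_supported() else self._fallback
        return backend.run(params)

print(AttentionDispatcher(100).run({"num_tokens": 10}))
print(AttentionDispatcher(70).run({"num_tokens": 10}))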
@@ -18,7 +18,7 @@

 #include "tensorrt_llm/common/cublasMMWrapper.h"
 #include "tensorrt_llm/common/quantization.h"
-#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h"
+#include "tensorrt_llm/kernels/fmhaDispatcher.h"
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/plugins/common/plugin.h"
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
@@ -114,7 +114,7 @@ private:
     cudaStream_t mNcclStream;

     // The default copy constructor will leave them as nullptr. clone() shall initialize it.
-    UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
+    UniqPtrWNullCopy<tensorrt_llm::kernels::FmhaDispatcher> mFmhaDispatcher;
     UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
 };

@@ -30,9 +30,9 @@ from . import graph_rewriting as gw
 from ._common import default_net, default_trtnet, precision
 from ._utils import (QuantModeWrapper, bf16_array, bool_array,
                      dim_resolve_negative, dim_to_trt_axes, dims_array,
-                     fp16_array, fp32_array, int32_array, int64_array,
-                     np_dtype_to_trt, str_dtype_to_trt, trt_dtype_to_np,
-                     trt_dtype_to_str)
+                     fp16_array, fp32_array, get_sm_version, int32_array,
+                     int64_array, np_dtype_to_trt, str_dtype_to_trt,
+                     trt_dtype_to_np, trt_dtype_to_str)
 from .network import PluginInfo, set_np_weight, set_plugin_info
 from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
 from .quantization import QuantMode
@@ -5719,7 +5719,8 @@ def gpt_attention(
     if (attention_mask is not None) or (attention_packed_mask is not None):
         # context fmha needs packed mask.
         assert attention_packed_mask is not None
-        mask_type = AttentionMaskType.custom_mask
+        if get_sm_version() < 100:
+            mask_type = AttentionMaskType.custom_mask

     mask_type_filed = trt.PluginField("mask_type",
                                       np.array([int(mask_type)], np.int32),
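On pre-SM100 GPUs an explicit attention mask still switches the plugin to the packed custom-mask path; on SM100 the mask type keeps its earlier value and (in the next hunk) the packed mask is not added to the plugin inputs. A sketch of the gate, assuming a get_sm_version-style helper built on torch.cuda.get_device_capability (names invented for the illustration):

import torch

def sm_version():
    # Compute capability as a two/three-digit number, e.g. 90 for 9.0, 100 for 10.0.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

def select_mask_inputs(attention_mask, packed_mask, sm, mask_type_is_custom):
    plug_inputs = []
    if attention_mask is not None and mask_type_is_custom:
        plug_inputs.append(attention_mask)   # full custom mask
    if packed_mask is not None and sm < 100:
        plug_inputs.append(packed_mask)      # packed custom mask, pre-SM100 only
    return plug_inputs

# Pre-SM100: both the full and the packed custom mask go to the plugin.
print(len(select_mask_inputs("mask", "packed", sm=90, mask_type_is_custom=True)))    # 2
# SM100: the custom-mask path is skipped entirely.
print(len(select_mask_inputs("mask", "packed", sm=100, mask_type_is_custom=False)))  # 0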
@@ -5844,7 +5845,7 @@ def gpt_attention(
     if attention_mask is not None and mask_type == AttentionMaskType.custom_mask:
         # useFullCustomMask
         plug_inputs += [attention_mask]
-    if attention_packed_mask is not None:
+    if attention_packed_mask is not None and get_sm_version() < 100:
         # usePackedCustomMask
         plug_inputs += [attention_packed_mask]
     if use_cache:
@@ -20,8 +20,8 @@ import tensorrt as trt
 import torch

 from .._common import default_net, precision
-from .._utils import (fp32_array, get_sm_version, int32_array, is_same_dtype,
-                      set_obj_attrs, trt_dtype_to_np, trt_dtype_to_str)
+from .._utils import (fp32_array, int32_array, is_same_dtype, set_obj_attrs,
+                      trt_dtype_to_np, trt_dtype_to_str)

 # isort: off
 from ..functional import (
@@ -1755,8 +1755,6 @@ class BertAttention(Module):
         if default_net().plugin_config.bert_attention_plugin:
             # TRT plugin mode
             assert input_lengths is not None
-            assert get_sm_version() < 100 or get_sm_version() >= 120, \
-                "bert_attention_plugin does not support SM100"
             context = bert_attention(
                 qkv,
                 input_lengths,
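With the plugin-side changes above in place, BertAttention no longer needs to reject SM100 parts, so this guard is dropped and the plugin path is taken on those GPUs as well. For reference, the removed predicate allowed everything except compute capabilities 100-119:

def old_guard_allows(sm):
    # The assert that was removed: reject only the SM100 range.
    return sm < 100 or sm >= 120

for sm in (80, 90, 100, 103, 120):
    print(sm, old_guard_allows(sm))
# 80 True, 90 True, 100 False, 103 False, 120 True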