[https://nvbugs/5302040][feat] Add whisper support (Bert Attention on SM100 and GPTAttention for cross attention on SM100) (#5527)

Signed-off-by: tinyinl <tinyinl@nvidia.com>
Tin-Yin Lai 2025-08-13 11:19:13 -07:00 committed by GitHub
parent bda42f8c3a
commit 6c52bb07ff
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
4 changed files with 29 additions and 15 deletions

View File

@@ -520,7 +520,7 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream);
cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t),
cudaMemcpyHostToDevice, stream);
-mFMHARunner->run(fmhaParams);
+mFmhaDispatcher->run(fmhaParams);
if (iter != 0)
{
invokeRecoverFromRA<T>((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_,
@@ -704,7 +704,18 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
}
// Run the fmha kernel.
-mFMHARunner->run(fmhaParams);
+// TODO: set it correctly for contiguous kv buffer (cross-attention).
+fmhaParams.totalKvSeqLen = num_tokens;
+fmhaParams.cuKvSeqLenPtr = cu_seqlens;
+fmhaParams.cuMaskRowsPtr = cu_seqlens;
+fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
+fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
+fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
+fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
+mFmhaDispatcher->run(fmhaParams);
sync_check_cuda_error(stream);
if (mSageAttn)
{
@@ -948,10 +959,14 @@ int BertAttentionPlugin::initialize() noexcept
}
-// Load kernels from the pre-compiled cubins.
-mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams));
+// The KV input data type. The default is same as dataType.
+fmhaParams.dataTypeKv = data_type;
+fmhaParams.headSizeV = mHeadSize;
+// Load kernels from the pre-compiled cubins.
+mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
// Fall back to unfused MHA kernels if not supported.
-mEnableContextFMHA = mFMHARunner->isFmhaSupported();
+mEnableContextFMHA = mFmhaDispatcher->isSupported();
}
#if ENABLE_MULTI_DEVICE

View File

@@ -18,7 +18,7 @@
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h"
#include "tensorrt_llm/kernels/fmhaDispatcher.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
@@ -114,7 +114,7 @@ private:
cudaStream_t mNcclStream;
// The default copy constructor will leave them as nullptr. clone() shall initialize it.
-UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
+UniqPtrWNullCopy<tensorrt_llm::kernels::FmhaDispatcher> mFmhaDispatcher;
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
};

View File

@@ -30,9 +30,9 @@ from . import graph_rewriting as gw
from ._common import default_net, default_trtnet, precision
from ._utils import (QuantModeWrapper, bf16_array, bool_array,
dim_resolve_negative, dim_to_trt_axes, dims_array,
-fp16_array, fp32_array, int32_array, int64_array,
-np_dtype_to_trt, str_dtype_to_trt, trt_dtype_to_np,
-trt_dtype_to_str)
+fp16_array, fp32_array, get_sm_version, int32_array,
+int64_array, np_dtype_to_trt, str_dtype_to_trt,
+trt_dtype_to_np, trt_dtype_to_str)
from .network import PluginInfo, set_np_weight, set_plugin_info
from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
from .quantization import QuantMode
@@ -5719,7 +5719,8 @@ def gpt_attention(
if (attention_mask is not None) or (attention_packed_mask is not None):
# context fmha needs packed mask.
assert attention_packed_mask is not None
-mask_type = AttentionMaskType.custom_mask
+if get_sm_version() < 100:
+    mask_type = AttentionMaskType.custom_mask
mask_type_filed = trt.PluginField("mask_type",
np.array([int(mask_type)], np.int32),
@@ -5844,7 +5845,7 @@
if attention_mask is not None and mask_type == AttentionMaskType.custom_mask:
# useFullCustomMask
plug_inputs += [attention_mask]
-if attention_packed_mask is not None:
+if attention_packed_mask is not None and get_sm_version() < 100:
# usePackedCustomMask
plug_inputs += [attention_packed_mask]
if use_cache:
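
The two functional.py hunks above gate the custom-mask path on the GPU architecture: below SM100 the mask type is switched to custom_mask and the packed custom mask is appended to the plugin inputs, while on SM100 the incoming mask type is kept and no packed mask is passed. A minimal sketch of that decision, assuming only what the diff shows (get_sm_version() from tensorrt_llm._utils returns the SM number of the current device) and using a hypothetical helper name, pick_mask_type, for illustration:

from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import AttentionMaskType


def pick_mask_type(default_mask_type, attention_mask, attention_packed_mask):
    """Hypothetical helper mirroring the gating above: context FMHA needs a
    packed mask, and the custom mask type is only selected below SM100."""
    if attention_mask is not None or attention_packed_mask is not None:
        # Context FMHA consumes the packed mask, so it must be provided.
        assert attention_packed_mask is not None
        if get_sm_version() < 100:
            return AttentionMaskType.custom_mask
    # On SM100 the caller's mask type is left unchanged and the packed mask
    # tensor is not added to the plugin inputs.
    return default_mask_type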

View File

@@ -20,8 +20,8 @@ import tensorrt as trt
import torch
from .._common import default_net, precision
-from .._utils import (fp32_array, get_sm_version, int32_array, is_same_dtype,
-set_obj_attrs, trt_dtype_to_np, trt_dtype_to_str)
+from .._utils import (fp32_array, int32_array, is_same_dtype, set_obj_attrs,
+trt_dtype_to_np, trt_dtype_to_str)
# isort: off
from ..functional import (
@@ -1755,8 +1755,6 @@ class BertAttention(Module):
if default_net().plugin_config.bert_attention_plugin:
# TRT plugin mode
assert input_lengths is not None
-assert get_sm_version() < 100 or get_sm_version() >= 120, \
-    "bert_attention_plugin does not support SM100"
context = bert_attention(
qkv,
input_lengths,
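
The removed assertion means BertAttention no longer rejects the plugin path on SM100-class GPUs. A hedged usage sketch, not part of this diff, showing how the relevant plugins would be enabled for a Whisper-style encoder/decoder build; the attribute names follow tensorrt_llm.plugin.PluginConfig and the dtype values are assumptions:

from tensorrt_llm.plugin import PluginConfig

plugin_config = PluginConfig()
# Encoder self-attention (the BertAttention path above), now usable on SM100.
plugin_config.bert_attention_plugin = "float16"
# Decoder self-attention and cross attention go through the GPT attention plugin.
plugin_config.gpt_attention_plugin = "float16"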