mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Yuxin <yuxinz@nvidia.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Yiqing Yan <yiqingy@nvidia.com> Signed-off-by: qqiao <qqiao@nvidia.com> Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com> Signed-off-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com> Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> Signed-off-by: Rashid K <rkaleem@nvidia.com> Signed-off-by: Zhenhuan Chen <chenzhh3671@gmail.com> Signed-off-by: Po-Wei Wang (Vincent) <poweiw@nvidia.com> Signed-off-by: Netanel Haber <nhaber@nvidia.com> Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Signed-off-by: Clay <ccs96307@gmail.com> Signed-off-by: Venky <23023424+venkywonka@users.noreply.github.com> Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com> Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com> Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com> Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co> Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com> Signed-off-by: Tailing Yuan <yuantailing@gmail.com> Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com> Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> Signed-off-by: Hui Gao <huig@nvidia.com> Signed-off-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com> Signed-off-by: jthomson04 <jwillthomson19@gmail.com> Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com> Signed-off-by: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com> Signed-off-by: Julien Debache <julien.debache@hotmail.com> Signed-off-by: Yanchao Lu <yanchaol@nvidia.com> Signed-off-by: Yiteng Niu <6831097+niukuo@users.noreply.github.com> Signed-off-by: Daniel Stokes <40156487+djns99@users.noreply.github.com> Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com> Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com> Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> Signed-off-by: David Clark <215764518+davidclark-nv@users.noreply.github.com> Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> Signed-off-by: JieXin Liang <Alcanderian@users.noreply.github.com> Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com> Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com> Signed-off-by: Yegor <75512761+Wokzy@users.noreply.github.com> Signed-off-by: Yegor Yershov <yegor6741@gmail.com> Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> Signed-off-by: raayandhar <rdhar@nvidia.com> Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com> Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Signed-off-by: xsimmons <xsimmons@nvidia.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com> Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com> Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com> Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com> Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com> Signed-off-by: Ubuntu <ubuntu@ip-10-0-20-146.us-west-2.compute.internal> Signed-off-by: Hanjun Cho <46752251+gkswns0531@users.noreply.github.com> Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com> Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com> Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com> Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com> Signed-off-by: narutolhy <582909902@qq.com> Signed-off-by: ZhanruiSunCh <184402041+ZhanruiSunCh@users.noreply.github.com> Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com> Signed-off-by: Frank <3429989+FrankD412@users.noreply.github.com> Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com> Signed-off-by: William Tambellini <wtambellini@sdl.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Yiqing Yan <yiqingy@nvidia.com> Co-authored-by: Emma Qiao <qqiao@nvidia.com> Co-authored-by: WeiHaocheng <20514172+WeiHaocheng@users.noreply.github.com> Co-authored-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com> Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Co-authored-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> Co-authored-by: Rashid Kaleem <4079439+arekay@users.noreply.github.com> Co-authored-by: Zhihan Jiang <68881590+nvzhihanj@users.noreply.github.com> Co-authored-by: Zhenhuan Chen <chenzhh3671@gmail.com> Co-authored-by: Po-Wei (Vincent) <poweiw@nvidia.com> Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Co-authored-by: Neta Zmora <nzmora@nvidia.com> Co-authored-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Co-authored-by: Clay <ccs96307@gmail.com> Co-authored-by: Venky <23023424+venkywonka@users.noreply.github.com> Co-authored-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> Co-authored-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Co-authored-by: Zheng Duan <200704041+zhengd-nv@users.noreply.github.com> Co-authored-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com> Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Co-authored-by: Frank <3429989+FrankD412@users.noreply.github.com> Co-authored-by: brb-nv <169953907+brb-nv@users.noreply.github.com> Co-authored-by: Linda <57756729+Linda-Stadter@users.noreply.github.com> Co-authored-by: Shunkangz <182541032+Shunkangz@users.noreply.github.com> Co-authored-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com> Co-authored-by: Tailing Yuan <yuantailing@gmail.com> Co-authored-by: Faraz <58580514+farazkh80@users.noreply.github.com> Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com> Co-authored-by: ixlmar <206748156+ixlmar@users.noreply.github.com> Co-authored-by: HuiGao-NV <huig@nvidia.com> Co-authored-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Co-authored-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Co-authored-by: Stefan Niebler <82932102+stnie@users.noreply.github.com> Co-authored-by: jthomson04 <jwillthomson19@gmail.com> Co-authored-by: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Julien Debache <jdebache@nvidia.com> Co-authored-by: Yanchao Lu <yanchaol@nvidia.com> Co-authored-by: Yiteng Niu <6831097+niukuo@users.noreply.github.com> Co-authored-by: Daniel Stokes <40156487+djns99@users.noreply.github.com> Co-authored-by: bhsueh_NV <11360707+byshiue@users.noreply.github.com> Co-authored-by: Bo Li <22713281+bobboli@users.noreply.github.com> Co-authored-by: ChristinaZ <83400082+ChristinaZ@users.noreply.github.com> Co-authored-by: Larry <197874197+LarryXFly@users.noreply.github.com> Co-authored-by: DylanChen-NV <191843203+DylanChen-NV@users.noreply.github.com> Co-authored-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com> Co-authored-by: davidclark-nv <215764518+davidclark-nv@users.noreply.github.com> Co-authored-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com> Co-authored-by: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com> Co-authored-by: liji-nv <59594262+liji-nv@users.noreply.github.com> Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com> Co-authored-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com> Co-authored-by: xiweny <13230610+VALLIS-NERIA@users.noreply.github.com> Co-authored-by: Yegor <75512761+Wokzy@users.noreply.github.com> Co-authored-by: Yukun He <23156053+hyukn@users.noreply.github.com> Co-authored-by: Raayan Dhar <58057652+raayandhar@users.noreply.github.com> Co-authored-by: Dom Brown <3886319+DomBrown@users.noreply.github.com> Co-authored-by: Chang Liu <9713593+chang-l@users.noreply.github.com> Co-authored-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com> Co-authored-by: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com> Co-authored-by: xavier-nvidia <xsimmons@nvidia.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Jhao-Ting Chen <jhaotingc@nvidia.com> Co-authored-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> Co-authored-by: Erin <14718778+hchings@users.noreply.github.com> Co-authored-by: chenfeiz0326 <chenfeiz@nvidia.com> Co-authored-by: dongxuy04 <78518666+dongxuy04@users.noreply.github.com> Co-authored-by: 2ez4bz <133824995+2ez4bz@users.noreply.github.com> Co-authored-by: Hanjun Cho <46752251+gkswns0531@users.noreply.github.com> Co-authored-by: Ubuntu <ubuntu@ip-10-0-20-146.us-west-2.compute.internal> Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com> Co-authored-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com> Co-authored-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com> Co-authored-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com> Co-authored-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com> Co-authored-by: narutolhy <582909902@qq.com> Co-authored-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com> Co-authored-by: wili <98001977+wili-65535@users.noreply.github.com> Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com> Co-authored-by: Void <18275976+yilin-void@users.noreply.github.com> Co-authored-by: William Tambellini <wtambellini@sdl.com>
106 lines
3.2 KiB
C++
106 lines
3.2 KiB
C++
/*
|
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <cassert>
|
|
#include <cstring>
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <tuple>
|
|
#include <vector>
|
|
|
|
#include <cublas_v2.h>
|
|
#include <cuda_fp16.h>
|
|
#include <cuda_runtime.h>
|
|
|
|
#include "fused_multihead_attention_common.h"
|
|
#include "fused_multihead_attention_v2.h"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tmaDescriptor.h"
|
|
|
|
namespace tensorrt_llm
|
|
{
|
|
namespace kernels
|
|
{
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Workflow of fmha runner:
|
|
// 1. check if FMHA kernels are supported statically.
|
|
// 2. construct FMHA runner object with the fixed params.
|
|
// 3. run the kernel (with all needed device pointers).
|
|
class FusedMHARunnerV2
|
|
{
|
|
public:
|
|
// Constructor.
|
|
FusedMHARunnerV2(MHARunnerFixedParams fixedParams);
|
|
|
|
// Deconstructor.
|
|
~FusedMHARunnerV2() = default; // for pimpl
|
|
|
|
// Check if any fmha kernel meets the requirements.
|
|
bool isFmhaSupported();
|
|
|
|
// Does FMHA need a separate Q and Kv input ?
|
|
bool isSeparateQAndKvInput() const
|
|
{
|
|
return mFixedParams.attentionInputLayout != AttentionInputLayout::PACKED_QKV;
|
|
}
|
|
|
|
// Run the fmha kernel.
|
|
void run(MHARunnerParams runnerParams);
|
|
|
|
private:
|
|
// Set the kernel params.
|
|
void setupKernelParams(MHARunnerParams runnerParams);
|
|
|
|
// Set the launch params to select kernels.
|
|
void setupLaunchParams(MHARunnerParams runnerParams);
|
|
|
|
// Set the tma descriptors.
|
|
void setTmaDescriptors(MHARunnerParams runnerParams);
|
|
|
|
// Check if it is a valid sequence length (only used by non-flash-attention kernels).
|
|
bool isValidS(int s) const;
|
|
|
|
// Get the kernel sequence that support the max sequence length (only used by non-flash-attention kernels).
|
|
int getSFromMaxSeqLen(int const max_seq_len) const;
|
|
|
|
private:
|
|
// The attention fixed params (mostly related to the attention structure).
|
|
MHARunnerFixedParams mFixedParams;
|
|
// The attention input params (runtime-known parameters).
|
|
MHARunnerParams mRunnerParams;
|
|
// The launch params to select the specific fmha kernel.
|
|
Launch_params mLaunchParams;
|
|
// The kernel params.
|
|
Fused_multihead_attention_params_v2 mKernelParams;
|
|
// The SM version.
|
|
int mSM = tensorrt_llm::common::getSMVersion();
|
|
// The multiple processor count.
|
|
int mMultiProcessorCount;
|
|
// The L2 cache size.
|
|
int mDeviceL2CacheSize;
|
|
// The total device memory.
|
|
size_t mTotalDeviceMemory;
|
|
// The class that stores all the kernels.
|
|
FusedMultiHeadAttentionXMMAKernelV2 const* xmmaKernel;
|
|
};
|
|
|
|
} // namespace kernels
|
|
} // namespace tensorrt_llm
|