TensorRT-LLMs/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h
Zhou Yuxin fca13b8c95
hopper-style context MLA (#5713)
Signed-off-by: Yuxin <yuxinz@nvidia.com>
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
Signed-off-by: qqiao <qqiao@nvidia.com>
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
Signed-off-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Signed-off-by: Rashid K <rkaleem@nvidia.com>
Signed-off-by: Zhenhuan Chen <chenzhh3671@gmail.com>
Signed-off-by: Po-Wei Wang (Vincent) <poweiw@nvidia.com>
Signed-off-by: Netanel Haber <nhaber@nvidia.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Clay <ccs96307@gmail.com>
Signed-off-by: Venky <23023424+venkywonka@users.noreply.github.com>
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Signed-off-by: Hui Gao <huig@nvidia.com>
Signed-off-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
Signed-off-by: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com>
Signed-off-by: Julien Debache <julien.debache@hotmail.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Signed-off-by: Yiteng Niu <6831097+niukuo@users.noreply.github.com>
Signed-off-by: Daniel Stokes <40156487+djns99@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: David Clark <215764518+davidclark-nv@users.noreply.github.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Signed-off-by: Yegor <75512761+Wokzy@users.noreply.github.com>
Signed-off-by: Yegor Yershov <yegor6741@gmail.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: raayandhar <rdhar@nvidia.com>
Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Signed-off-by: xsimmons <xsimmons@nvidia.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Ubuntu <ubuntu@ip-10-0-20-146.us-west-2.compute.internal>
Signed-off-by: Hanjun Cho <46752251+gkswns0531@users.noreply.github.com>
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: narutolhy <582909902@qq.com>
Signed-off-by: ZhanruiSunCh <184402041+ZhanruiSunCh@users.noreply.github.com>
Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
Signed-off-by: Frank <3429989+FrankD412@users.noreply.github.com>
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
Signed-off-by: William Tambellini <wtambellini@sdl.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Yiqing Yan <yiqingy@nvidia.com>
Co-authored-by: Emma Qiao <qqiao@nvidia.com>
Co-authored-by: WeiHaocheng <20514172+WeiHaocheng@users.noreply.github.com>
Co-authored-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com>
Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Co-authored-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Co-authored-by: Rashid Kaleem <4079439+arekay@users.noreply.github.com>
Co-authored-by: Zhihan Jiang <68881590+nvzhihanj@users.noreply.github.com>
Co-authored-by: Zhenhuan Chen <chenzhh3671@gmail.com>
Co-authored-by: Po-Wei (Vincent) <poweiw@nvidia.com>
Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Neta Zmora <nzmora@nvidia.com>
Co-authored-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: Clay <ccs96307@gmail.com>
Co-authored-by: Venky <23023424+venkywonka@users.noreply.github.com>
Co-authored-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Co-authored-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Co-authored-by: Zheng Duan <200704041+zhengd-nv@users.noreply.github.com>
Co-authored-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Co-authored-by: Frank <3429989+FrankD412@users.noreply.github.com>
Co-authored-by: brb-nv <169953907+brb-nv@users.noreply.github.com>
Co-authored-by: Linda <57756729+Linda-Stadter@users.noreply.github.com>
Co-authored-by: Shunkangz <182541032+Shunkangz@users.noreply.github.com>
Co-authored-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Co-authored-by: Tailing Yuan <yuantailing@gmail.com>
Co-authored-by: Faraz <58580514+farazkh80@users.noreply.github.com>
Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com>
Co-authored-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Co-authored-by: HuiGao-NV <huig@nvidia.com>
Co-authored-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Co-authored-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Co-authored-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Julien Debache <jdebache@nvidia.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Yiteng Niu <6831097+niukuo@users.noreply.github.com>
Co-authored-by: Daniel Stokes <40156487+djns99@users.noreply.github.com>
Co-authored-by: bhsueh_NV <11360707+byshiue@users.noreply.github.com>
Co-authored-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: ChristinaZ <83400082+ChristinaZ@users.noreply.github.com>
Co-authored-by: Larry <197874197+LarryXFly@users.noreply.github.com>
Co-authored-by: DylanChen-NV <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com>
Co-authored-by: davidclark-nv <215764518+davidclark-nv@users.noreply.github.com>
Co-authored-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
Co-authored-by: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com>
Co-authored-by: liji-nv <59594262+liji-nv@users.noreply.github.com>
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Co-authored-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Co-authored-by: xiweny <13230610+VALLIS-NERIA@users.noreply.github.com>
Co-authored-by: Yegor <75512761+Wokzy@users.noreply.github.com>
Co-authored-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Co-authored-by: Raayan Dhar <58057652+raayandhar@users.noreply.github.com>
Co-authored-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Co-authored-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Co-authored-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
Co-authored-by: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com>
Co-authored-by: xavier-nvidia <xsimmons@nvidia.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Co-authored-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Co-authored-by: Erin <14718778+hchings@users.noreply.github.com>
Co-authored-by: chenfeiz0326 <chenfeiz@nvidia.com>
Co-authored-by: dongxuy04 <78518666+dongxuy04@users.noreply.github.com>
Co-authored-by: 2ez4bz <133824995+2ez4bz@users.noreply.github.com>
Co-authored-by: Hanjun Cho <46752251+gkswns0531@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-10-0-20-146.us-west-2.compute.internal>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
Co-authored-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Co-authored-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Co-authored-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
Co-authored-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Co-authored-by: narutolhy <582909902@qq.com>
Co-authored-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com>
Co-authored-by: wili <98001977+wili-65535@users.noreply.github.com>
Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com>
Co-authored-by: Void <18275976+yilin-void@users.noreply.github.com>
Co-authored-by: William Tambellini <wtambellini@sdl.com>
2025-07-23 14:37:20 +08:00

106 lines
3.2 KiB
C++

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <cstring>
#include <iostream>
#include <memory>
#include <tuple>
#include <vector>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "fused_multihead_attention_common.h"
#include "fused_multihead_attention_v2.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tmaDescriptor.h"
namespace tensorrt_llm
{
namespace kernels
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Workflow of the fmha runner:
// 1. check statically whether FMHA kernels are supported.
// 2. construct the FMHA runner object with the fixed params.
// 3. run the kernel (with all needed device pointers).
//
// Only isSeparateQAndKvInput() is defined inline; every other member function is declared here and
// defined out of line.
class FusedMHARunnerV2
{
public:
    // Constructor. Receives the fixed attention params by value.
    FusedMHARunnerV2(MHARunnerFixedParams fixedParams);

    // Destructor. Defaulted, so nothing held through the raw xmmaKernel pointer is released here.
    // NOTE(review): the trailing "for pimpl" remark predates this edit; no pimpl member is visible in
    // this class - confirm whether the remark is stale.
    ~FusedMHARunnerV2() = default; // for pimpl

    // Check if any fmha kernel meets the requirements.
    bool isFmhaSupported();

    // Does FMHA need a separate Q and Kv input ?
    // True for every attention input layout other than AttentionInputLayout::PACKED_QKV.
    bool isSeparateQAndKvInput() const
    {
        return mFixedParams.attentionInputLayout != AttentionInputLayout::PACKED_QKV;
    }

    // Run the fmha kernel. The runtime params are received by value.
    void run(MHARunnerParams runnerParams);

private:
    // Set the kernel params.
    void setupKernelParams(MHARunnerParams runnerParams);

    // Set the launch params to select kernels.
    void setupLaunchParams(MHARunnerParams runnerParams);

    // Set the tma descriptors.
    void setTmaDescriptors(MHARunnerParams runnerParams);

    // Check if it is a valid sequence length (only used by non-flash-attention kernels).
    bool isValidS(int s) const;

    // Get the kernel sequence length that supports the max sequence length (only used by
    // non-flash-attention kernels).
    int getSFromMaxSeqLen(int const max_seq_len) const;

private:
    // NOTE: keep the declaration order of the data members below unchanged; members are initialized
    // in declaration order.

    // The attention fixed params (mostly related to the attention structure).
    MHARunnerFixedParams mFixedParams;
    // The attention input params (runtime-known parameters).
    MHARunnerParams mRunnerParams;
    // The launch params to select the specific fmha kernel.
    Launch_params mLaunchParams;
    // The kernel params.
    Fused_multihead_attention_params_v2 mKernelParams;
    // The SM version, queried through tensorrt_llm::common::getSMVersion() when the object is built.
    int mSM = tensorrt_llm::common::getSMVersion();

    // The scalar members below default to zero (and the pointer to nullptr) so that they never hold
    // an indeterminate value, even on a path that does not assign them. Presumably the constructor
    // overwrites them with the real device properties - verify against the out-of-line definition.

    // The multiple processor count.
    int mMultiProcessorCount = 0;
    // The L2 cache size.
    int mDeviceL2CacheSize = 0;
    // The total device memory.
    size_t mTotalDeviceMemory = 0;
    // The class that stores all the kernels (raw pointer; not deleted by this class).
    FusedMultiHeadAttentionXMMAKernelV2 const* xmmaKernel = nullptr;
};
} // namespace kernels
} // namespace tensorrt_llm