* fp8 kv + bf16 ctx MLA + fp8 gen MLA
Use BF16 for context MLA.
mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.
Allow mSM==90 for mFP8GenerationMLA==true.
For FMHA, dataTypeKv should be FP8.
For FP8 MLA generation, the output is still in BF16.
Refine debug info for FMHA kernel metadata.
Use inputType, outputType, SM together to hash kernel list.
Add FP8 MLA generation FMHA kernel.
Special WAR of NUM_COMPUTE_GROUPS for MLA generation kernel.
Separate the implementation of fused_multihead_attention_v2.h to CPP and print some debug info if checkIfKernelExist fails.
Refine debug info in fused_multihead_attention_v2.cpp
Correct FP8 MLA metadata.
New kernel provided by Yuxin, which outputs BF16.
smem size is not set correctly, which will lead to illegal mem access.
Yuxin fixed the error in FMHA MLA kernel: previously the BF16 isn't correctly written: some parts are repeatedly written, while some others are untouched.
There are two bmm1 scales that should be set correctly.
New kernel generated by Yuxin.
Modificatiosn to common/attentionOp for FP8 MLA on Hopper using FMHA.
Not necessary. If mFP8GenerationMLA, is_fp8_out is false, so mFP8ContextFMHA is false.
Skip a check in fmhaDispatcher.
Modifications in fmhaRunner:
- Debug dump.
- if (!isFP8GenerationMLA) skips a lot of flag setting.
- TMA descriptor modification for qo (by Yuxin).
Cleanup debug output.
Clean up o tma descriptor modifications.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Resolve conflicts.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Apply the patch of FP8 FlashMLA and resolve conflicts.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Fix compilation error.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Fix compile error.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* pick blackwell support
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
* Add copyright notice to fused_multihead_attention_v2.cpp.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Add license.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Add missing license.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Exclude building flashMLA kernels under sm90.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Revert "Exclude building flashMLA kernels under sm90."
This reverts commit f0c859d459.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Use macro to skip compiling FlashMLA for non sm90 targets.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
---------
Signed-off-by: Bo Li <bobboli0202@gmail.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Dylan Chen <ziqingc@nvidia.com>
Co-authored-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
* init trtllm attn no cache
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* fix: fix the seq_len issue and attn metadata prepare for qwen reward model test
fix: fix minor bugs after rebase
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: remove unnecessary debug logs and clean up commented code
refactor: update max_seq_len documentation and remove max_seq_len for decoder model contructor in PyTorchModelEngine
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: update calculate_ref_result function to accept tensor inputs and mask type, enhance test_attention_no_cache to support FULL and CAUSAL masks
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: remove unused BERT attention metadata conversion method and add type assertion for no cache attention in PyTorchModelEngine
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: remove use_kv_cache parameter from attention function and related classes, update documentation for KV cache handling
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: implement setAttentionMaskType method for better mask type handling and remove unused conversion function
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: streamline KV cache handling by replacing direct member access with useKVCache method and simplify token per block assignment
remove Debug code.
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: Resolve comments for Python code
Simplify no cache attention metadata preparation and streamline related attributes in TrtllmAttentionMetadata
Removed the private method for converting to no cache attention metadata and integrated its logic into the prepare method. Updated the test for BERT sequence classification to reflect these changes and ensure proper handling of attention metadata.
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* docs: Add is_dummy_attention field to attention metadata for simulation operations
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: add KVCacheParams to attention backend interface and import relevant metadata classes
Updated the attention backend interface to include KVCacheParams and imported TrtllmAttentionMetadata and VanillaAttentionMetadata in model_engine.py for enhanced functionality.
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* fix: fix rebase format issue
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* fix: extend attention mask type handling in MHARunnerFixedParams
Added support for additional attention mask types (BIDIRECTIONAL, BIDIRECTIONALGLM, BLOCKSPARSE) in the MHARunnerFixedParams structure to fix the mapping issue between ContextAttentionMaskType and AttentionMaskType
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* fix: enhance attention mask type handling in TllmGenFmhaRunnerParams
Updated the setAttentionMaskType method to include a switch-case structure for better handling of attention mask types, ensuring proper mapping and error handling for invalid types.
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
---------
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
* refactor: Update gatherTree function to accept CUDA stream parameter
This commit modifies the gatherTree function signature to include a runtime::CudaStream parameter, enhancing flexibility in stream management. Additionally, it removes unnecessary buffer manager parameters and stream handling from the function, streamlining the code. The finalize method in GptDecoderBatched is also updated to reflect these changes, improving clarity and maintainability in the decoding process.
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
* refactor: Update GptDecoderBatched finalize
This commit refactors the GptDecoderBatched class to improve method signatures and reduce code complexity:
- Modified finalize method to accept DecoderState as a parameter
- Updated method signatures to work with the new DecoderState approach
- Improved code organization and readability
The changes continue the ongoing refactoring to centralize decoder state management and simplify the decoder implementation.
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
---------
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
* chore: update cutlass to v3.8.0
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
* refactor: update include directives for consistency and organization in weightOnlyBatchedGemv headers
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
* Fix fpA_intB_gemm compilation
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
---------
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>