mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-27 22:23:25 +08:00
Merge remote-tracking branch 'gitlab/main' into user/xiweny/merge_main_0819
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
commit
8b532363ce
@ -20,7 +20,7 @@ language: "en-US"
|
||||
reviews:
|
||||
profile: chill
|
||||
auto_title_placeholder: '@coderabbitai title'
|
||||
auto_title_instructions: 'Should follow the format: "[fix/feat/doc/infra/...] \<summary of this PR\>". Keep it concise.'
|
||||
auto_title_instructions: 'Format: "[<category>] <title>". Category must be one of: fix, feat, doc, infra, style, refactor, perf, test, chore, revert. Enclose the category in square brackets. Title should be concise (<= 60 chars). Example: "[feat] Add logit_bias support".'
|
||||
commit_status: false
|
||||
collapse_walkthrough: true
|
||||
assess_linked_issues: true
|
||||
|
||||
24
.editorconfig
Normal file
24
.editorconfig
Normal file
@ -0,0 +1,24 @@
|
||||
# Auto basic formatting when saving file with EditorConfig https://editorconfig.org/
|
||||
|
||||
# top-most EditorConfig file
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
# make
|
||||
[Makefile*]
|
||||
indent_style = tab
|
||||
indent_size = 4
|
||||
|
||||
# c++
|
||||
[*.{cpp,cu,h}]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
|
||||
# python
|
||||
[*.py]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
54
.github/CODEOWNERS
vendored
54
.github/CODEOWNERS
vendored
@ -6,13 +6,39 @@
|
||||
# Without approval from a member of this team, PRs cannot be merged to release branches.
|
||||
# * @NVIDIA/trt-llm-release-branch-approval
|
||||
|
||||
## TensorRT-LLM Infra
|
||||
### CI
|
||||
/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs
|
||||
### Setup
|
||||
/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs
|
||||
### Github workflows
|
||||
/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs
|
||||
/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs
|
||||
|
||||
## TensorRT-LLM - Docs
|
||||
/docs @NVIDIA/trt-llm-doc-owners
|
||||
|
||||
## Examples
|
||||
/examples @NVIDIA/trt-llm-doc-owners
|
||||
|
||||
## TensorRT-LLM - Triton backend
|
||||
/triton_backend @NVIDIA/trt-llm-triton-backend-devs
|
||||
|
||||
# TensorRT-LLM Pytorch backend
|
||||
/tensorrt_llm/_torch @NVIDIA/trt-llm-torch-devs
|
||||
|
||||
## TensorRT-LLM Pytorch - Modules
|
||||
/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules
|
||||
|
||||
## TensorRT-LLM Pytorch Models
|
||||
/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs
|
||||
/examples/models @NVIDIA/trt-llm-torch-models-devs @NVIDIA/trt-llm-doc-owners
|
||||
|
||||
## TensorRT-LLM Pytorch backend - runtime
|
||||
/tensorrt_llm/_torch/pyexecutor @NVIDIA/trt-llm-torch-runtime-devs
|
||||
## TensorRT-LLM Pytorch backend - AutoDeploy flow
|
||||
/tensorrt_llm/_torch/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs
|
||||
/tensorrt_llm/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs
|
||||
/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs @NVIDIA/trt-llm-doc-owners
|
||||
|
||||
## TensorRT-LLM Pytorch - Speculative Decoding
|
||||
/tensorrt_llm/_torch/speculative @NVIDIA/trt-llm-torch-spec-decoding
|
||||
@ -31,12 +57,6 @@
|
||||
/tensorrt_llm/_torch/attention_backend @NVIDIA/trt-llm-torch-attention-devs
|
||||
/tensorrt_llm/_torch/modules/attention.py @NVIDIA/trt-llm-torch-attention-devs
|
||||
|
||||
## TensorRT-LLM Pytorch - Modules
|
||||
/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules
|
||||
|
||||
|
||||
## TensorRT-LLM Pytorch Models
|
||||
/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs
|
||||
|
||||
### TensorRT-LLM Pytorch - Models - Gemma
|
||||
/tensorrt_llm/_torch/models/modeling_gemma3.py @NVIDIA/trt-llm-torch-models-gemma-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
@ -87,7 +107,7 @@
|
||||
/tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
/tensorrt_llm/_torch/models/modeling_nemotron_h.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
/tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
/tensorrt_llm/_torch/pyexecutor/resource_manager.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-runtime-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
/tensorrt_llm/_torch/modules/mamba @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
/tests/unittest/_torch/modeling/test_modeling_nemotron.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
|
||||
@ -108,8 +128,6 @@
|
||||
/cpp/tensorrt_llm/runtime/loraUtils.cpp @NVIDIA/trt-llm-torch-peft
|
||||
/cpp/tensorrt_llm/runtime/loraUtils.h @NVIDIA/trt-llm-torch-peft
|
||||
|
||||
## TensorRT-LLM - Triton backend
|
||||
/triton_backend @NVIDIA/trt-llm-triton-backend-devs
|
||||
|
||||
## TensorRT-LLM trtllm-bench Reviewers
|
||||
/tensorrt_llm/bench @NVIDIA/trtllm-bench-reviewers
|
||||
@ -121,10 +139,9 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
|
||||
/tensorrt_llm/executor @NVIDIA/trt-llm-llmapi-devs
|
||||
|
||||
## TensorRT-LLM LLM Disaggregated
|
||||
/examples/disaggregated @NVIDIA/trt-llm-disagg-devs
|
||||
/examples/disaggregated @NVIDIA/trt-llm-disagg-devs @NVIDIA/trt-llm-doc-owners
|
||||
/tensorrt_llm/disaggregated_params.py @NVIDIA/trt-llm-disagg-devs
|
||||
/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @NVIDIA/trt-llm-disagg-devs
|
||||
/tensorrt_llm/_torch/pyexecutor/py_executor.py @NVIDIA/trt-llm-disagg-devs
|
||||
/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @NVIDIA/trt-llm-disagg-devs
|
||||
/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @NVIDIA/trt-llm-disagg-devs
|
||||
/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @NVIDIA/trt-llm-disagg-devs
|
||||
@ -135,19 +152,6 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
|
||||
/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @NVIDIA/trt-llm-disagg-devs
|
||||
/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @NVIDIA/trt-llm-disagg-devs
|
||||
|
||||
## TensorRT-LLM Infra
|
||||
|
||||
### CI
|
||||
/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs
|
||||
### Setup
|
||||
/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs
|
||||
### Github workflows
|
||||
/tensorrt_llm/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs
|
||||
/tensorrt_llm/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs
|
||||
|
||||
## TensorRT-LLM - Docs
|
||||
/docs @NVIDIA/trt-llm-doc-owners
|
||||
/examples @NVIDIA/trt-llm-doc-owners
|
||||
|
||||
# The rule below requires that any PR modifying public APIs must be approved by at least one member
|
||||
# of the NVIDIA/trt-llm-committed-api-review-committee or NVIDIA/trt-llm-noncommitted-api-review-committee team.
|
||||
|
||||
66
.github/ISSUE_TEMPLATE/01-installation.yml
vendored
Normal file
66
.github/ISSUE_TEMPLATE/01-installation.yml
vendored
Normal file
@ -0,0 +1,66 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/200-installation.yml
|
||||
name: 🛠️ Installation
|
||||
description: Report an issue here when you hit errors during installation.
|
||||
title: "[Installation]: "
|
||||
labels: ["Installation"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please provide the following system information to help us debug your installation issue:
|
||||
|
||||
```bash
|
||||
# System information
|
||||
cat /etc/os-release
|
||||
nvidia-smi
|
||||
nvcc --version
|
||||
python --version
|
||||
pip list | grep -E "(tensorrt|torch|cuda)"
|
||||
|
||||
# TensorRT-LLM installation method and version
|
||||
pip show tensorrt_llm
|
||||
```
|
||||
value: |
|
||||
**System Information:**
|
||||
- OS:
|
||||
- Python version:
|
||||
- CUDA version:
|
||||
- GPU model(s):
|
||||
- Driver version:
|
||||
- TensorRT version:
|
||||
- PyTorch version:
|
||||
- TensorRT-LLM version:
|
||||
|
||||
**Detailed output:**
|
||||
```text
|
||||
Paste the output of the above commands here
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: How you are installing TensorRT-LLM
|
||||
description: |
|
||||
Paste the full command you are trying to execute or describe your installation method.
|
||||
value: |
|
||||
```sh
|
||||
# Installation command or method
|
||||
pip install tensorrt_llm
|
||||
```
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and checked the [installation documentation](https://nvidia.github.io/TensorRT-LLM/installation/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions.
|
||||
required: true
|
||||
41
.github/ISSUE_TEMPLATE/02-new-model.yml
vendored
Normal file
41
.github/ISSUE_TEMPLATE/02-new-model.yml
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/600-new-model.yml
|
||||
name: 🤗 Support request for a new model from huggingface
|
||||
description: Submit a proposal/request for a new model from huggingface
|
||||
title: "[New Model]: "
|
||||
labels: ["new model"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
|
||||
#### We also highly recommend you read https://nvidia.github.io/TensorRT-LLM/architecture/add-model.html first to understand how to add a new model.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The model to consider.
|
||||
description: >
|
||||
A huggingface identifier, pointing to the model, e.g. `meta-llama/Llama-3.1-8B-Instruct` .
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The closest model TensorRT-LLM already supports.
|
||||
description: >
|
||||
Here is the list of models already supported by TensorRT-LLM: https://github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/models (TRT backend) and https://github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/_torch/models (Pytorch backend) . Which model is the most similar to the model you want to add support for?
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: What's your difficulty of supporting the model you want?
|
||||
description: >
|
||||
For example, any new operators or new architecture?
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions.
|
||||
required: true
|
||||
31
.github/ISSUE_TEMPLATE/03-documentation.yml
vendored
Normal file
31
.github/ISSUE_TEMPLATE/03-documentation.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/100-documentation.yml
|
||||
name: 📚 Documentation
|
||||
description: Report an issue related to https://nvidia.github.io/TensorRT-LLM/
|
||||
title: "[Doc]: "
|
||||
labels: ["Documentation"]
|
||||
assignees: ["nv-guomingz"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 📚 The doc issue
|
||||
description: >
|
||||
A clear and concise description of what content in https://nvidia.github.io/TensorRT-LLM/ is an issue.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Suggest a potential alternative/fix
|
||||
description: >
|
||||
Tell us how we could improve the documentation in this regard.
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions.
|
||||
required: true
|
||||
62
.github/ISSUE_TEMPLATE/04-questions.yml
vendored
Normal file
62
.github/ISSUE_TEMPLATE/04-questions.yml
vendored
Normal file
@ -0,0 +1,62 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/300-usage.yml
|
||||
name: 💻 Questions
|
||||
description: Raise an issue here if you don't know how to use TensorRT-LLM.
|
||||
title: "[Usage]: "
|
||||
labels: ["question"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please provide the following system information to help us debug your usage issue:
|
||||
|
||||
```bash
|
||||
# System information
|
||||
nvidia-smi
|
||||
python --version
|
||||
pip show tensorrt_llm
|
||||
```
|
||||
value: |
|
||||
**System Information:**
|
||||
- OS:
|
||||
- Python version:
|
||||
- CUDA version:
|
||||
- GPU model(s):
|
||||
- Driver version:
|
||||
- TensorRT-LLM version:
|
||||
|
||||
**Detailed output:**
|
||||
```text
|
||||
Paste the output of the above commands here
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: How would you like to use TensorRT-LLM
|
||||
description: |
|
||||
A detailed description of how you want to use TensorRT-LLM.
|
||||
value: |
|
||||
I want to run inference of a [specific model](put Hugging Face link here). I don't know how to integrate it with TensorRT-LLM or optimize it for my use case.
|
||||
|
||||
**Specific questions:**
|
||||
- Model:
|
||||
- Use case (e.g., chatbot, batch inference, real-time serving):
|
||||
- Expected throughput/latency requirements:
|
||||
- Multi-GPU setup needed:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions.
|
||||
required: true
|
||||
40
.github/ISSUE_TEMPLATE/05-feature-request.yml
vendored
Normal file
40
.github/ISSUE_TEMPLATE/05-feature-request.yml
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/500-feature-request.yml
|
||||
name: 🚀 Feature request
|
||||
description: Submit a proposal/request for a new TensorRT-LLM feature
|
||||
title: "[Feature]: "
|
||||
labels: ["feature request"]
|
||||
assignees: ["laikhtewari"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 🚀 The feature, motivation and pitch
|
||||
description: >
|
||||
A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Alternatives
|
||||
description: >
|
||||
A description of any alternative solutions or features you've considered, if any.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Additional context
|
||||
description: >
|
||||
Add any other context or screenshots about the feature request.
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions.
|
||||
required: true
|
||||
191
.github/ISSUE_TEMPLATE/06-bug-report.yml
vendored
Normal file
191
.github/ISSUE_TEMPLATE/06-bug-report.yml
vendored
Normal file
@ -0,0 +1,191 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/400-bug-report.yml
|
||||
name: "🐛 Bug Report"
|
||||
description: Submit a bug report to help us improve TensorRT-LLM
|
||||
title: "[Bug]: "
|
||||
labels: [ "bug" ]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
⚠️ **SECURITY WARNING:** Please review any text you paste to ensure it does not contain sensitive information such as:
|
||||
- API tokens or keys (e.g., Hugging Face tokens, OpenAI API keys)
|
||||
- Passwords or authentication credentials
|
||||
- Private URLs or endpoints
|
||||
- Personal or confidential data
|
||||
|
||||
Consider redacting or replacing sensitive values with placeholders like `<YOUR_TOKEN_HERE>` when sharing configuration or code examples.
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: Please share your system info with us.
|
||||
placeholder: |
|
||||
- CPU architecture (e.g., x86_64, aarch64)
|
||||
- CPU/Host memory size (if known)
|
||||
- GPU properties
|
||||
- GPU name (e.g., NVIDIA H100, NVIDIA A100, NVIDIA L40S)
|
||||
- GPU memory size (if known)
|
||||
- Clock frequencies used (if applicable)
|
||||
- Libraries
|
||||
- TensorRT-LLM branch or tag (e.g., main, v0.7.1)
|
||||
- TensorRT-LLM commit (if known)
|
||||
- Versions of TensorRT, Modelopt, CUDA, cuBLAS, etc. used
|
||||
- Container used (if running TensorRT-LLM in a container)
|
||||
- NVIDIA driver version
|
||||
- OS (Ubuntu 24.04, CentOS 8)
|
||||
- Any other information that may be useful in reproducing the bug
|
||||
|
||||
**Commands to gather system information:**
|
||||
```bash
|
||||
nvidia-smi
|
||||
nvcc --version
|
||||
python --version
|
||||
pip show tensorrt_llm tensorrt torch
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: who-can-help
|
||||
attributes:
|
||||
label: Who can help?
|
||||
description: |
|
||||
To expedite the response to your issue, it would be helpful if you could identify the appropriate person
|
||||
to tag using the **@** symbol. Here is a general guideline on **whom to tag**.
|
||||
|
||||
Rest assured that all issues are reviewed by the core maintainers. If you are unsure about whom to tag,
|
||||
you can leave it blank, and a core maintainer will make sure to involve the appropriate person.
|
||||
|
||||
Please tag fewer than 3 people.
|
||||
|
||||
Quantization: @Tracin
|
||||
|
||||
Documentation: @juney-nvidia
|
||||
|
||||
Feature request: @laikhtewari
|
||||
|
||||
Performance: @kaiyux
|
||||
|
||||
placeholder: "@Username ..."
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
attributes:
|
||||
label: Information
|
||||
description: 'The problem arises when using:'
|
||||
options:
|
||||
- label: "The official example scripts"
|
||||
- label: "My own modified scripts"
|
||||
|
||||
- type: checkboxes
|
||||
id: information-tasks
|
||||
attributes:
|
||||
label: Tasks
|
||||
description: "The tasks I am working on are:"
|
||||
options:
|
||||
- label: "An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)"
|
||||
- label: "My own task or dataset (give details below)"
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
Please provide a clear and concise description of what the bug is and how to reproduce it.
|
||||
|
||||
If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example:
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm.sampling_params import SamplingParams
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
|
||||
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
|
||||
|
||||
Remember to use code tags to properly format your code. You can refer to the
|
||||
link https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting for guidance on code formatting.
|
||||
|
||||
Please refrain from using screenshots, as they can be difficult to read and prevent others from copying and pasting your code.
|
||||
It would be most helpful if we could reproduce your issue by simply copying and pasting your scripts and codes.
|
||||
|
||||
Please set the environment variable `export TLLM_DEBUG_MODE=1` to turn on more logging to help debugging potential issues.
|
||||
|
||||
placeholder: |
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
```python
|
||||
# Sample code to reproduce the problem
|
||||
```
|
||||
|
||||
```
|
||||
The error message you got, with the full traceback and the error logs.
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: "Provide a brief summary of the expected behavior of the software. Provide output files or examples if possible."
|
||||
|
||||
- type: textarea
|
||||
id: actual-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: actual behavior
|
||||
description: "Describe the actual behavior of the software and how it deviates from the expected behavior. Provide output files or examples if possible."
|
||||
|
||||
- type: textarea
|
||||
id: additional-notes
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: additional notes
|
||||
description: "Provide any additional context here you think might be useful for the TensorRT-LLM team to help debug this issue (such as experiments done, potential things to investigate)."
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
⚠️ Please separate bugs of `transformers`, `pytorch` implementation or usage from bugs of `TensorRT-LLM`.
|
||||
|
||||
- If the error only appears in TensorRT-LLM, please provide the detailed script of how you run `TensorRT-LLM`, also highlight the difference and what you expect.
|
||||
|
||||
Thanks for reporting 🙏!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions.
|
||||
required: true
|
||||
74
.github/ISSUE_TEMPLATE/07-performance-discussion.yml
vendored
Normal file
74
.github/ISSUE_TEMPLATE/07-performance-discussion.yml
vendored
Normal file
@ -0,0 +1,74 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/700-performance-discussion.yml
|
||||
name: ⚡ Discussion on the performance of TensorRT-LLM
|
||||
description: Submit a proposal/discussion about the performance of TensorRT-LLM
|
||||
title: "[Performance]: "
|
||||
labels: ["Performance"]
|
||||
assignees: ["byshiue", "kaiyux"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Proposal to improve performance
|
||||
description: >
|
||||
How do you plan to improve TensorRT-LLM's performance?
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Report of performance regression
|
||||
description: >
|
||||
Please provide detailed description of performance comparison to confirm the regression. You may want to run the benchmark script at https://github.com/NVIDIA/TensorRT-LLM/tree/main/benchmarks .
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Misc discussion on performance
|
||||
description: >
|
||||
Anything about the performance.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your current environment (if you think it is necessary)
|
||||
description: |
|
||||
Please provide the following system information to help with performance analysis:
|
||||
|
||||
```bash
|
||||
# System information
|
||||
nvidia-smi
|
||||
nvcc --version
|
||||
python --version
|
||||
pip show tensorrt_llm tensorrt torch
|
||||
```
|
||||
value: |
|
||||
**System Information:**
|
||||
- OS:
|
||||
- Python version:
|
||||
- CUDA version:
|
||||
- GPU model(s):
|
||||
- Driver version:
|
||||
- TensorRT version:
|
||||
- PyTorch version:
|
||||
- TensorRT-LLM version:
|
||||
|
||||
**Detailed output:**
|
||||
```text
|
||||
Paste the output of the above commands here
|
||||
```
|
||||
validations:
|
||||
required: false
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions.
|
||||
required: true
|
||||
58
.github/ISSUE_TEMPLATE/08-RFC.yml
vendored
Normal file
58
.github/ISSUE_TEMPLATE/08-RFC.yml
vendored
Normal file
@ -0,0 +1,58 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/750-RFC.yml
|
||||
name: 💬 Request for comments (RFC).
|
||||
description: Ask for feedback on major architectural changes or design choices.
|
||||
title: "[RFC]: "
|
||||
labels: ["RFC"]
|
||||
assignees: ["laikhtewari"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Please take a look at previous [RFCs](https://github.com/NVIDIA/TensorRT-LLM/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Motivation.
|
||||
description: >
|
||||
The motivation of the RFC.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Proposed Change.
|
||||
description: >
|
||||
The proposed change of the RFC.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Feedback Period.
|
||||
description: >
|
||||
The feedback period of the RFC. Usually at least one week.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: CC List.
|
||||
description: >
|
||||
The list of people you want to CC.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Any Other Things.
|
||||
description: >
|
||||
Any other things you would like to mention.
|
||||
validations:
|
||||
required: false
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉! The TensorRT-LLM team reviews RFCs during regular team meetings. Most RFCs can be discussed online, but you can also reach out to the team through GitHub discussions or issues for additional feedback.
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions.
|
||||
required: true
|
||||
114
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
114
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -1,114 +0,0 @@
|
||||
name: "Bug Report"
|
||||
description: Submit a bug report to help us improve TensorRT-LLM
|
||||
labels: [ "bug" ]
|
||||
body:
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: Please share your system info with us.
|
||||
placeholder: |
|
||||
- CPU architecture (e.g., x86_64, aarch64)
|
||||
- CPU/Host memory size (if known)
|
||||
- GPU properties
|
||||
- GPU name (e.g., NVIDIA H100, NVIDIA A100, NVIDIA L40S)
|
||||
- GPU memory size (if known)
|
||||
- Clock frequencies used (if applicable)
|
||||
- Libraries
|
||||
- TensorRT-LLM branch or tag (e.g., main, v0.7.1)
|
||||
- TensorRT-LLM commit (if known)
|
||||
- Versions of TensorRT, Modelopt, CUDA, cuBLAS, etc. used
|
||||
- Container used (if running TensorRT-LLM in a container)
|
||||
- NVIDIA driver version
|
||||
- OS (Ubuntu 24.04, CentOS 8)
|
||||
- Any other information that may be useful in reproducing the bug
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: who-can-help
|
||||
attributes:
|
||||
label: Who can help?
|
||||
description: |
|
||||
To expedite the response to your issue, it would be helpful if you could identify the appropriate person
|
||||
to tag using the **@** symbol. Here is a general guideline on **whom to tag**.
|
||||
|
||||
Rest assured that all issues are reviewed by the core maintainers. If you are unsure about whom to tag,
|
||||
you can leave it blank, and a core maintainer will make sure to involve the appropriate person.
|
||||
|
||||
Please tag fewer than 3 people.
|
||||
|
||||
Quantization: @Tracin
|
||||
|
||||
Documentation: @juney-nvidia
|
||||
|
||||
Feature request: @ncomly-nvidia
|
||||
|
||||
Performance: @kaiyux
|
||||
|
||||
placeholder: "@Username ..."
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
attributes:
|
||||
label: Information
|
||||
description: 'The problem arises when using:'
|
||||
options:
|
||||
- label: "The official example scripts"
|
||||
- label: "My own modified scripts"
|
||||
|
||||
- type: checkboxes
|
||||
id: information-tasks
|
||||
attributes:
|
||||
label: Tasks
|
||||
description: "The tasks I am working on are:"
|
||||
options:
|
||||
- label: "An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)"
|
||||
- label: "My own task or dataset (give details below)"
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
Kindly share a code example that demonstrates the issue you encountered. It is recommending to provide a code snippet directly.
|
||||
Additionally, if you have any error messages, or stack traces related to the problem, please include them here.
|
||||
|
||||
Remember to use code tags to properly format your code. You can refer to the
|
||||
link https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting for guidance on code formatting.
|
||||
|
||||
Please refrain from using screenshots, as they can be difficult to read and prevent others from copying and pasting your code.
|
||||
It would be most helpful if we could reproduce your issue by simply copying and pasting your scripts and codes.
|
||||
|
||||
placeholder: |
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: "Provide a brief summary of the expected behavior of the software. Provide output files or examples if possible."
|
||||
|
||||
- type: textarea
|
||||
id: actual-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: actual behavior
|
||||
description: "Describe the actual behavior of the software and how it deviates from the expected behavior. Provide output files or examples if possible."
|
||||
|
||||
- type: textarea
|
||||
id: additioanl-notes
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: additional notes
|
||||
description: "Provide any additional context here you think might be useful for the TensorRT-LLM team to help debug this issue (such as experiments done, potential things to investigate)."
|
||||
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 🤔 Questions
|
||||
url: https://github.com/NVIDIA/TensorRT-LLM/discussions
|
||||
about: Ask questions and discuss with other TensorRT-LLM community members
|
||||
8
.github/pull_request_template.md
vendored
8
.github/pull_request_template.md
vendored
@ -18,6 +18,14 @@ Examples:
|
||||
- [https://nvbugs/1234567][fix] Fix some bugs
|
||||
- [#1234][doc] Update documentation
|
||||
- [None][chore] Minor clean-up
|
||||
|
||||
Alternative (faster) way using CodeRabbit AI:
|
||||
|
||||
**[JIRA ticket/NVBugs ID/GitHub issue/None] @coderabbitai title**
|
||||
|
||||
NOTE: "@coderabbitai title" will be replaced by the title generated by CodeRabbit AI, that includes the "[type]" and title.
|
||||
For more info, see /.coderabbit.yaml.
|
||||
|
||||
-->
|
||||
|
||||
## Description
|
||||
|
||||
@ -27,7 +27,7 @@ repos:
|
||||
args: [--allow-multiple-documents]
|
||||
exclude: ".*/gitlab/.*.yml"
|
||||
- id: trailing-whitespace
|
||||
exclude: '\.patch$'
|
||||
exclude: '\.(patch|md)$'
|
||||
- id: check-toml
|
||||
- id: mixed-line-ending
|
||||
args: [--fix=lf]
|
||||
|
||||
@ -9,7 +9,7 @@ TensorRT-LLM
|
||||
[](https://www.python.org/downloads/release/python-31012/)
|
||||
[](https://developer.nvidia.com/cuda-downloads)
|
||||
[](https://developer.nvidia.com/tensorrt)
|
||||
[](./tensorrt_llm/version.py)
|
||||
[](./tensorrt_llm/version.py)
|
||||
[](./LICENSE)
|
||||
|
||||
[Architecture](./docs/source/torch/arch_overview.md) | [Performance](./docs/source/performance/perf-overview.md) | [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) | [Documentation](./docs/source/) | [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
|
||||
@ -253,5 +253,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer
|
||||
## Useful Links
|
||||
- [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT-LLM.
|
||||
- [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT-LLM.
|
||||
- [AutoDeploy](./examples/auto_deploy/README.md): An experimental backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
|
||||
- [AutoDeploy](./examples/auto_deploy/README.md): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
|
||||
- [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT-LLM Q&A and news.
|
||||
|
||||
@ -336,7 +336,7 @@ cd cpp/build
|
||||
`disaggServerBenchmark` only supports `decoder-only` models.
|
||||
Here is the basic usage:
|
||||
```
|
||||
export TRTLLM_USE_MPI_KVCACHE=1
|
||||
export TRTLLM_USE_UCX_KVCACHE=1
|
||||
mpirun -n ${proc} benchmarks/disaggServerBenchmark --context_engine_dirs ${context_engine_0},${context_engine_1}...,${context_engine_{m-1}} \
|
||||
--generation_engine_dirs ${generation_engine_0},${generation_engine_1}...,${generation_engine_{n-1}} --dataset ${dataset_path}
|
||||
```
|
||||
@ -344,7 +344,7 @@ This command will launch m context engines and n generation engines. You need to
|
||||
|
||||
for example:
|
||||
```
|
||||
export TRTLLM_USE_MPI_KVCACHE=1
|
||||
export TRTLLM_USE_UCX_KVCACHE=1
|
||||
mpirun -n 7 benchmarks/disaggServerBenchmark --context_engine_dirs ${llama_7b_tp2_pp1_dir},${llama_7b_tp1_pp1_dir} --generation_engine_dirs ${llama_7b_tp1_pp1_dir},${llama_7b_tp2_pp1_dir} --dataset ${dataset_path}
|
||||
|
||||
# need 6 gpus and 7 processes to launch the benchmark.
|
||||
|
||||
@ -495,6 +495,17 @@ if(ENABLE_UCX)
|
||||
if(NOT ${ucx_FOUND})
|
||||
set(ENABLE_UCX 0)
|
||||
else()
|
||||
if(DEFINED ENV{GITHUB_MIRROR} AND NOT "$ENV{GITHUB_MIRROR}" STREQUAL "")
|
||||
if(EXISTS "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake")
|
||||
file(READ "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake" FILE_CONTENTS)
|
||||
string(
|
||||
REPLACE "https://raw.githubusercontent.com/rapidsai/rapids-cmake"
|
||||
"$ENV{GITHUB_MIRROR}/rapidsai/rapids-cmake/raw/refs/heads"
|
||||
FILE_CONTENTS "${FILE_CONTENTS}")
|
||||
file(WRITE "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake" "${FILE_CONTENTS}")
|
||||
message(WARNING "Replace UCXX fetch_rapids.cmake with internal mirror")
|
||||
endif()
|
||||
endif()
|
||||
# installing ucxx via add_subdirectory results in strange cudart linking
|
||||
# error, thus using their installation script to isolate the installation
|
||||
# process until the issue is understood. And always trigger the build so
|
||||
|
||||
@ -75,27 +75,19 @@ public:
|
||||
std::vector<executor::LookaheadDecodingConfig>>
|
||||
operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
|
||||
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
|
||||
runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
|
||||
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
|
||||
SizeType32 maxSequenceLength, SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const;
|
||||
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
|
||||
CudaStream const& runtimeStream, CudaStream const& decoderStream, SizeType32 maxSequenceLength,
|
||||
SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const;
|
||||
|
||||
[[nodiscard]] std::tuple<std::vector<runtime::ITensor::SharedConstPtr>,
|
||||
std::vector<executor::LookaheadDecodingConfig>>
|
||||
createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
|
||||
executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
|
||||
runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
|
||||
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
|
||||
nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
|
||||
runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream,
|
||||
SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;
|
||||
|
||||
private:
|
||||
//! @brief Initialize the decoder at `batchSlot` with a new `request`. Exposed only for static batching via
|
||||
//! GptDecoderBatched::newBatch()
|
||||
static void newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request,
|
||||
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
|
||||
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
|
||||
SizeType32 maxSequenceLength);
|
||||
|
||||
//! @brief Setups decoder internal tensors for new speculative decoding request
|
||||
static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
|
||||
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
@ -36,7 +37,8 @@ using BlockPtr = std::shared_ptr<KVCacheBlock>;
|
||||
class KVCacheEventManager
|
||||
{
|
||||
public:
|
||||
explicit KVCacheEventManager(size_t maxKVEventEntries);
|
||||
explicit KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank = std::nullopt,
|
||||
std::optional<SizeType32> attentionDpSize = std::nullopt, SizeType32 attentionDpEventsGatherPeriodMs = 5);
|
||||
|
||||
~KVCacheEventManager();
|
||||
KVCacheEventManager(KVCacheEventManager& other) = delete;
|
||||
@ -61,14 +63,19 @@ public:
|
||||
// Worker thread which adds events to mEvents.
|
||||
void worker();
|
||||
|
||||
// Thread which exchanges events if attentionDP is enabled
|
||||
void exchangeAttentionDpThread();
|
||||
|
||||
private:
|
||||
// Add an event to mEventQueue
|
||||
void enqueueEvent(executor::KVCacheEvent&& event);
|
||||
|
||||
/// @brief Flag to terminate the worker
|
||||
bool mRun;
|
||||
std::atomic<bool> mRun;
|
||||
/// @brief Worker thread
|
||||
std::thread mWorkerThread;
|
||||
/// @brief Exchange thread for attention DP events
|
||||
std::thread mExchangeAttentionDpThread;
|
||||
|
||||
/// @brief The deque of events
|
||||
std::deque<executor::KVCacheEvent> mEvents;
|
||||
@ -91,6 +98,17 @@ private:
|
||||
size_t mMaxSize;
|
||||
/// @brief An auto-incrementing event id counter
|
||||
size_t mEventId;
|
||||
|
||||
/// @brief Attention DP ranks and size
|
||||
/// If set, we will exchange KV cache events and accumulate on rank 0
|
||||
std::optional<SizeType32> mAttentionDpRank;
|
||||
std::optional<SizeType32> mAttentionDpSize;
|
||||
|
||||
/// @brief The period in milliseconds to gather attention DP events across rank
|
||||
SizeType32 mAttentionDpEventsGatherPeriodMs;
|
||||
|
||||
/// @brief MPI communicator for attention DP
|
||||
std::unique_ptr<tensorrt_llm::mpi::MpiComm> mMpiComm;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
@ -536,8 +536,7 @@ public:
|
||||
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
|
||||
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
|
||||
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
|
||||
bool copyOnPartialReuse);
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
|
||||
|
||||
~WindowBlockManager();
|
||||
|
||||
@ -633,11 +632,6 @@ public:
|
||||
return mAllBlocksById.at(blockId);
|
||||
}
|
||||
|
||||
[[nodiscard]] BlockMapIterRange getBlocksByHash(size_t hash) const
|
||||
{
|
||||
return mContextBlocksByHash.equal_range(hash);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getTokensPerBlock() const noexcept
|
||||
{
|
||||
return mTokensPerBlock;
|
||||
@ -723,10 +717,6 @@ public:
|
||||
//! \param blockIds Id of each block.
|
||||
void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
|
||||
|
||||
void addBlockToHashMap(BlockPtr const& block);
|
||||
|
||||
void removeBlockFromHashMap(BlockPtr const& block);
|
||||
|
||||
[[nodiscard]] bool verifyQueueIntegrity();
|
||||
|
||||
// Only needed when sliding window attention + paged context fmha are used together.
|
||||
@ -808,8 +798,6 @@ private:
|
||||
SizeType32 mTokensPerBlock;
|
||||
// List of all blocks by idx
|
||||
std::vector<BlockPtr> mAllBlocksById;
|
||||
// List of all context blocks by hash
|
||||
BlockMap mContextBlocksByHash;
|
||||
// Dummy block acting as root for BlockToken searches
|
||||
BlockPtr mCachedBlocksRoot;
|
||||
// KV cache type (self or cross)
|
||||
@ -841,8 +829,6 @@ private:
|
||||
double mReusedTokens;
|
||||
// Total number of input tokens
|
||||
double mTotalInputTokens;
|
||||
// Whether or not to maintain a hashmap of blocks.
|
||||
bool mEnableHashKey;
|
||||
// Whether blocks that are partially matched should be reused.
|
||||
bool mEnablePartialReuse;
|
||||
// Whether partially matched blocks that are already in use should be copied and reused.
|
||||
@ -863,8 +849,8 @@ public:
|
||||
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
|
||||
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
|
||||
bool enablePartialReuse = true, bool copyOnPartialReuse = true);
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
|
||||
bool copyOnPartialReuse = true);
|
||||
|
||||
BlockManager(BlockManager const&) = delete;
|
||||
BlockManager& operator=(BlockManager const&) = delete;
|
||||
@ -1081,11 +1067,6 @@ public:
|
||||
return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
|
||||
}
|
||||
|
||||
[[nodiscard]] WindowBlockManager::BlockMapIterRange getBlocksByHash(size_t hash, SizeType32 windowSize) const
|
||||
{
|
||||
return mWindowBlockManagers.at(windowSize).getBlocksByHash(hash);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
|
||||
{
|
||||
return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); });
|
||||
@ -1096,16 +1077,6 @@ public:
|
||||
return getPool(poolIdx).containsBlockScales;
|
||||
}
|
||||
|
||||
void addBlockToHashMap(BlockPtr const& block, SizeType32 windowSize)
|
||||
{
|
||||
mWindowBlockManagers.at(windowSize).addBlockToHashMap(block);
|
||||
}
|
||||
|
||||
void removeBlockFromHashMap(BlockPtr const& block, SizeType32 windowSize)
|
||||
{
|
||||
mWindowBlockManagers.at(windowSize).removeBlockFromHashMap(block);
|
||||
}
|
||||
|
||||
//! \brief Store context blocks
|
||||
void storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest);
|
||||
|
||||
@ -1385,8 +1356,8 @@ public:
|
||||
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
|
||||
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
|
||||
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
|
||||
bool copyOnpartialReuse = true);
|
||||
|
||||
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
|
||||
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
|
||||
@ -1405,8 +1376,8 @@ public:
|
||||
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
|
||||
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
|
||||
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
|
||||
bool copyOnpartialReuse = true);
|
||||
|
||||
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
|
||||
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
|
||||
@ -1692,8 +1663,6 @@ private:
|
||||
std::unordered_map<LlmRequest::RequestIdType, GenerationRequest> mSequences;
|
||||
// Whether to cache KV pages for reuse
|
||||
bool mEnableBlockReuse;
|
||||
// Whether enable finding blocks by their hash, ignored when reuse enabled
|
||||
bool mEnableHashKey;
|
||||
// Mutex to protect access to mSequences
|
||||
mutable std::mutex mSequencesMtx;
|
||||
// buffers for static tensors, will be created after allocating pools
|
||||
|
||||
@ -828,8 +828,10 @@ public:
|
||||
// for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
|
||||
mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
|
||||
: LlmRequestState::kCONTEXT_INIT;
|
||||
mContextCurrentPosition = 0;
|
||||
mPrepopulatedPromptLen = 0;
|
||||
mContextCurrentPositionTarget = 0;
|
||||
mContextCurrentPositionDraft = 0;
|
||||
mPrepopulatedPromptLenTarget = 0;
|
||||
mPrepopulatedPromptLenDraft = 0;
|
||||
mContextChunkSize = mPromptLen;
|
||||
mSeqSlot.reset();
|
||||
}
|
||||
@ -1049,7 +1051,7 @@ public:
|
||||
|
||||
[[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
|
||||
{
|
||||
return mPrepopulatedPromptLen;
|
||||
return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
|
||||
}
|
||||
|
||||
void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
|
||||
@ -1066,7 +1068,10 @@ public:
|
||||
"Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen,
|
||||
promptLen, mRequestId);
|
||||
TLLM_CHECK(prepopulatedPromptLen < promptLen);
|
||||
mPrepopulatedPromptLen = prepopulatedPromptLen;
|
||||
|
||||
auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
|
||||
auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
|
||||
prePromptLen = prepopulatedPromptLen;
|
||||
|
||||
if (prepopulatedPromptLen > 0)
|
||||
{
|
||||
@ -1081,7 +1086,7 @@ public:
|
||||
chunkSize = flooredEndPosition - prepopulatedPromptLen;
|
||||
TLLM_CHECK(chunkSize <= getContextChunkSize());
|
||||
}
|
||||
setContextCurrentPosition(prepopulatedPromptLen);
|
||||
contextCurrentPosition = prepopulatedPromptLen;
|
||||
setContextChunkSize(chunkSize);
|
||||
|
||||
if (!isLastContextChunk())
|
||||
@ -1522,14 +1527,15 @@ public:
|
||||
|
||||
void setContextCurrentPosition(SizeType32 contextCurrentPosition)
|
||||
{
|
||||
mContextCurrentPosition = contextCurrentPosition;
|
||||
mContextCurrentPositionDraft = contextCurrentPosition;
|
||||
mContextCurrentPositionTarget = contextCurrentPosition;
|
||||
}
|
||||
|
||||
/// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
|
||||
/// or end of the context is returned.
|
||||
[[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept
|
||||
{
|
||||
return mContextCurrentPosition;
|
||||
return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
|
||||
}
|
||||
|
||||
/// Return the length of the context that has not yet been processed.
|
||||
@ -1570,14 +1576,16 @@ public:
|
||||
{
|
||||
// The number of cached token is encountered in mContextCurrentPosition,
|
||||
// so the start position of the context is mPrepopulatedPromptLen.
|
||||
return mContextCurrentPosition == mPrepopulatedPromptLen;
|
||||
return getContextCurrentPosition() == getPrepopulatedPromptLen();
|
||||
}
|
||||
|
||||
/// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
|
||||
void moveToNextContextChunk()
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
|
||||
mContextCurrentPosition += getContextChunkSize();
|
||||
|
||||
mContextCurrentPositionDraft += getContextChunkSize();
|
||||
mContextCurrentPositionTarget += getContextChunkSize();
|
||||
setContextChunkSize(0);
|
||||
}
|
||||
|
||||
@ -1843,6 +1851,16 @@ public:
|
||||
return mIsDummyRequest;
|
||||
}
|
||||
|
||||
void setUseDraftModel(bool useDraftModel)
|
||||
{
|
||||
mUseDraftModel = useDraftModel;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool useDraftModel() const
|
||||
{
|
||||
return mUseDraftModel;
|
||||
}
|
||||
|
||||
RequestIdType mRequestId;
|
||||
SizeType32 mPromptLen;
|
||||
SizeType32 mMaxNewTokens;
|
||||
@ -1885,7 +1903,8 @@ protected:
|
||||
// Number of tokens already in KV cache before context phase.
|
||||
// A value > 0 indicates cached KV cache blocks were reused.
|
||||
// Up to inputLen - 1 tokens can be reused.
|
||||
SizeType32 mPrepopulatedPromptLen{0};
|
||||
SizeType32 mPrepopulatedPromptLenTarget{0};
|
||||
SizeType32 mPrepopulatedPromptLenDraft{0};
|
||||
|
||||
SizeType32 mMaxSentTokenLen;
|
||||
|
||||
@ -1916,7 +1935,8 @@ protected:
|
||||
// The size of the context chunk must be multiple of the KV-Cache block size except the last one.
|
||||
// Value `0` means Chunked-Context is disabled.
|
||||
SizeType32 mContextChunkSize{0};
|
||||
SizeType32 mContextCurrentPosition{0};
|
||||
SizeType32 mContextCurrentPositionTarget{0};
|
||||
SizeType32 mContextCurrentPositionDraft{0};
|
||||
|
||||
std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
|
||||
VecLogProbs mCumLogProbs; // [beamSize]
|
||||
@ -2017,6 +2037,8 @@ protected:
|
||||
|
||||
bool mIsDummyRequest{false};
|
||||
|
||||
bool mUseDraftModel{false};
|
||||
|
||||
private:
|
||||
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
|
||||
{
|
||||
@ -2027,7 +2049,7 @@ private:
|
||||
|
||||
// Scatter the input tokens to other beam
|
||||
mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
|
||||
mLastTokens = VecTokens(mSamplingConfig.beamWidth);
|
||||
mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back());
|
||||
|
||||
// Init mUniqueTokens
|
||||
VecUniqueTokens uniqueTokens{inputTokens.size()};
|
||||
@ -2347,6 +2369,9 @@ public:
|
||||
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager);
|
||||
|
||||
void moveLoraWeightsToGpu(runtime::BufferManager const& manager);
|
||||
|
||||
// Remove LoRA weights and LoRA config tensors
|
||||
void removeLoraTensors();
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -16,25 +16,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, char const* info)
|
||||
{
|
||||
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info).c_str());
|
||||
}
|
||||
|
||||
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
|
||||
{
|
||||
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()).c_str());
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
class DebugConfig
|
||||
{
|
||||
public:
|
||||
@ -86,12 +69,3 @@ public:
|
||||
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__).c_str()); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_THROW(...) \
|
||||
do \
|
||||
{ \
|
||||
throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_WRAP(ex) \
|
||||
NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what())
|
||||
|
||||
@ -122,6 +122,16 @@ public:
|
||||
return QuantMode(BaseType(1u) << 14);
|
||||
}
|
||||
|
||||
static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 15);
|
||||
}
|
||||
|
||||
static constexpr QuantMode w4a16Mxfp4() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 16);
|
||||
}
|
||||
|
||||
constexpr BaseType value() const noexcept
|
||||
{
|
||||
return mValue;
|
||||
@ -202,6 +212,16 @@ public:
|
||||
return isSet(w4a8Mxfp4Fp8());
|
||||
}
|
||||
|
||||
constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept
|
||||
{
|
||||
return isSet(w4a8Mxfp4Mxfp8());
|
||||
}
|
||||
|
||||
constexpr bool hasW4a16Mxfp4() const noexcept
|
||||
{
|
||||
return isSet(w4a16Mxfp4());
|
||||
}
|
||||
|
||||
constexpr bool hasKvCacheQuant() const noexcept
|
||||
{
|
||||
return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache();
|
||||
@ -209,7 +229,8 @@ public:
|
||||
|
||||
static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken,
|
||||
bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq,
|
||||
bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8)
|
||||
bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8,
|
||||
bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4)
|
||||
{
|
||||
QuantMode quantMode{};
|
||||
if (quantizeWeights)
|
||||
@ -278,25 +299,35 @@ public:
|
||||
quantMode += w4a8Mxfp4Fp8();
|
||||
}
|
||||
|
||||
if (useW4a8Mxfp4Mxfp8)
|
||||
{
|
||||
quantMode += w4a8Mxfp4Mxfp8();
|
||||
}
|
||||
|
||||
if (useW4a16Mxfp4)
|
||||
{
|
||||
quantMode += w4a16Mxfp4();
|
||||
}
|
||||
|
||||
return quantMode;
|
||||
}
|
||||
|
||||
static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
|
||||
{
|
||||
return fromDescription(
|
||||
true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false);
|
||||
return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false,
|
||||
false, false, false, false);
|
||||
}
|
||||
|
||||
static constexpr QuantMode useQServe(bool perGroup)
|
||||
{
|
||||
return fromDescription(
|
||||
true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false);
|
||||
return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false,
|
||||
false, false, false);
|
||||
}
|
||||
|
||||
static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
|
||||
{
|
||||
return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false,
|
||||
false, false, false);
|
||||
false, false, false, false, false);
|
||||
}
|
||||
|
||||
static QuantMode const fromQuantAlgo(
|
||||
@ -353,28 +384,38 @@ public:
|
||||
}
|
||||
else if (quantAlgo == "FP8")
|
||||
{
|
||||
quantMode = fromDescription(
|
||||
false, false, false, false, false, false, false, false, true, false, false, false, false, false);
|
||||
quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false,
|
||||
false, false, false, false, false);
|
||||
}
|
||||
else if (quantAlgo == "FP8_ROWWISE")
|
||||
{
|
||||
quantMode = fromDescription(
|
||||
false, false, true, true, false, false, false, false, false, true, false, false, false, false);
|
||||
quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false,
|
||||
false, false, false, false);
|
||||
}
|
||||
else if (quantAlgo == "FP4")
|
||||
{
|
||||
quantMode = fromDescription(
|
||||
false, false, false, false, false, false, false, false, false, false, false, true, false, false);
|
||||
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
|
||||
true, false, false, false, false);
|
||||
}
|
||||
else if (quantAlgo == "FP8_BLOCK_SCALES")
|
||||
{
|
||||
quantMode = fromDescription(
|
||||
false, false, false, false, false, false, false, false, false, false, false, false, true, false);
|
||||
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
|
||||
false, true, false, false, false);
|
||||
}
|
||||
else if (quantAlgo == "W4A8_MXFP4_FP8")
|
||||
{
|
||||
quantMode = fromDescription(
|
||||
false, false, false, false, false, false, false, false, false, false, false, false, false, true);
|
||||
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
|
||||
false, false, true, false, false);
|
||||
}
|
||||
else if (quantAlgo == "W4A8_MXFP4_MXFP8")
|
||||
{
|
||||
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
|
||||
false, false, false, true, false);
|
||||
}
|
||||
else if (quantAlgo == "W4A16_MXFP4")
|
||||
{
|
||||
quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false,
|
||||
false, false, false, false, true);
|
||||
}
|
||||
|
||||
if (kvCacheQuantAlgo == "INT8")
|
||||
|
||||
@ -74,6 +74,28 @@ void printElement(std::ostream& os, std::tuple<Args...> const& t)
|
||||
printTupleImpl(os, t, std::index_sequence_for<Args...>{});
|
||||
}
|
||||
|
||||
class va_list_guard
|
||||
{
|
||||
public:
|
||||
explicit va_list_guard(va_list& args)
|
||||
: mArgs(args)
|
||||
{
|
||||
}
|
||||
|
||||
~va_list_guard()
|
||||
{
|
||||
va_end(mArgs);
|
||||
}
|
||||
|
||||
va_list_guard(va_list_guard const&) = delete;
|
||||
va_list_guard& operator=(va_list_guard const&) = delete;
|
||||
va_list_guard(va_list_guard&&) = delete;
|
||||
va_list_guard& operator=(va_list_guard&&) = delete;
|
||||
|
||||
private:
|
||||
va_list& mArgs;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Override operator<< for any tuple
|
||||
@ -117,6 +139,8 @@ inline std::string fmtstr(char const* format, ...)
|
||||
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
va_list_guard args_guard(args);
|
||||
|
||||
fmtstr_(
|
||||
format,
|
||||
[](void* target, size_t count) -> char*
|
||||
@ -131,7 +155,6 @@ inline std::string fmtstr(char const* format, ...)
|
||||
return str->data();
|
||||
},
|
||||
&result, args);
|
||||
va_end(args);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -16,11 +16,22 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#define TLLM_THROW(...) \
|
||||
do \
|
||||
{ \
|
||||
throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_WRAP(ex) \
|
||||
NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what())
|
||||
|
||||
#define NEW_TLLM_EXCEPTION(...) \
|
||||
tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__).c_str())
|
||||
|
||||
@ -45,4 +56,14 @@ private:
|
||||
int mNbFrames;
|
||||
};
|
||||
|
||||
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, char const* info)
|
||||
{
|
||||
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info).c_str());
|
||||
}
|
||||
|
||||
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
|
||||
{
|
||||
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()).c_str());
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
@ -1001,6 +1001,7 @@ public:
|
||||
std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
|
||||
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
|
||||
bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
|
||||
SizeType32 attentionDpEventsGatherPeriodMs = 5,
|
||||
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);
|
||||
|
||||
[[nodiscard]] bool getEnableBlockReuse() const;
|
||||
@ -1016,6 +1017,7 @@ public:
|
||||
[[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
|
||||
[[nodiscard]] size_t getEventBufferMaxSize() const;
|
||||
[[nodiscard]] bool getUseUvm() const;
|
||||
[[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const;
|
||||
|
||||
void setEnableBlockReuse(bool enableBlockReuse);
|
||||
void setEnablePartialReuse(bool enablePartialReuse);
|
||||
@ -1030,6 +1032,7 @@ public:
|
||||
void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
|
||||
void setEventBufferMaxSize(size_t eventBufferMaxSize);
|
||||
void setUseUvm(bool useUvm);
|
||||
void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs);
|
||||
|
||||
void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults);
|
||||
|
||||
@ -1085,6 +1088,9 @@ private:
|
||||
|
||||
/// @brief Whether to use UVM for the KV cache.
|
||||
bool mUseUvm;
|
||||
|
||||
/// @brief The period in milliseconds to gather attention DP events across ranks
|
||||
SizeType32 mAttentionDpEventsGatherPeriodMs;
|
||||
};
|
||||
|
||||
/// @brief Configuration class for the runtime perf knobs
|
||||
@ -1702,6 +1708,12 @@ struct KVCacheUpdatedData
|
||||
explicit KVCacheUpdatedData(IdType blockHash)
|
||||
: blockHash{blockHash} {};
|
||||
|
||||
explicit KVCacheUpdatedData(IdType blockHash, std::optional<KVCacheEventDiff<SizeType32>> cacheLevel,
|
||||
std::optional<KVCacheEventDiff<SizeType32>> priority)
|
||||
: blockHash{blockHash}
|
||||
, cacheLevel{cacheLevel}
|
||||
, priority{priority} {};
|
||||
|
||||
KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue)
|
||||
{
|
||||
cacheLevel = KVCacheEventDiff<SizeType32>{oldValue, newValue};
|
||||
@ -1726,8 +1738,8 @@ using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVC
|
||||
|
||||
struct KVCacheEvent
|
||||
{
|
||||
|
||||
KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize);
|
||||
KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize,
|
||||
std::optional<SizeType32> attentionDpRank = std::nullopt);
|
||||
|
||||
/// @brief The unique id of this event
|
||||
IdType eventId;
|
||||
@ -1735,6 +1747,8 @@ struct KVCacheEvent
|
||||
KVCacheEventData data;
|
||||
/// @brief The sliding window size
|
||||
SizeType32 windowSize;
|
||||
/// @brief The attention DP rank of the event, if applicable
|
||||
std::optional<SizeType32> attentionDpRank;
|
||||
};
|
||||
|
||||
/// @brief Exposes a limited set of KV cache manager functionalities
|
||||
|
||||
@ -302,6 +302,53 @@ public:
|
||||
[[nodiscard]] static std::vector<RequestStatsPerIteration> deserializeRequestStatsPerIterationVec(
|
||||
std::vector<char>& buffer);
|
||||
|
||||
// KVCacheEvent deque
|
||||
[[nodiscard]] static std::vector<char> serialize(std::deque<KVCacheEvent> const& kvCacheEvents);
|
||||
[[nodiscard]] static std::deque<KVCacheEvent> deserializeKVCacheEvents(std::vector<char>& buffer);
|
||||
|
||||
// KVCacheEvent
|
||||
[[nodiscard]] static size_t serializedSize(KVCacheEvent const& event);
|
||||
static void serialize(KVCacheEvent const& event, std::ostream& os);
|
||||
[[nodiscard]] static KVCacheEvent deserializeKVCacheEvent(std::istream& is);
|
||||
|
||||
// KVCacheCreatedData
|
||||
[[nodiscard]] static size_t serializedSize(KVCacheCreatedData const& data);
|
||||
static void serialize(KVCacheCreatedData const& data, std::ostream& os);
|
||||
[[nodiscard]] static KVCacheCreatedData deserializeKVCacheCreatedData(std::istream& is);
|
||||
|
||||
// KVCacheStoredData
|
||||
[[nodiscard]] static size_t serializedSize(KVCacheStoredData const& data);
|
||||
static void serialize(KVCacheStoredData const& data, std::ostream& os);
|
||||
[[nodiscard]] static KVCacheStoredData deserializeKVCacheStoredData(std::istream& is);
|
||||
|
||||
// KVCacheStoredBlockData
|
||||
[[nodiscard]] static size_t serializedSize(KVCacheStoredBlockData const& data);
|
||||
static void serialize(KVCacheStoredBlockData const& data, std::ostream& os);
|
||||
[[nodiscard]] static KVCacheStoredBlockData deserializeKVCacheStoredBlockData(std::istream& is);
|
||||
|
||||
// KVCacheRemovedData
|
||||
[[nodiscard]] static size_t serializedSize(KVCacheRemovedData const& data);
|
||||
static void serialize(KVCacheRemovedData const& data, std::ostream& os);
|
||||
[[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is);
|
||||
|
||||
// KVCacheEventDiff
|
||||
template <typename T>
|
||||
[[nodiscard]] static size_t serializedSize(KVCacheEventDiff<T> const& data);
|
||||
template <typename T>
|
||||
static void serialize(KVCacheEventDiff<T> const& data, std::ostream& os);
|
||||
template <typename T>
|
||||
[[nodiscard]] static KVCacheEventDiff<T> deserializeKVCacheEventDiff(std::istream& is);
|
||||
|
||||
// KVCacheUpdateData
|
||||
[[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data);
|
||||
static void serialize(KVCacheUpdatedData const& data, std::ostream& os);
|
||||
[[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is);
|
||||
|
||||
// UniqueToken
|
||||
[[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token);
|
||||
static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os);
|
||||
[[nodiscard]] static tensorrt_llm::runtime::UniqueToken deserializeUniqueToken(std::istream& is);
|
||||
|
||||
// String
|
||||
static std::string deserializeString(std::istream& is);
|
||||
|
||||
|
||||
@ -51,13 +51,13 @@ public:
|
||||
DecoderState();
|
||||
|
||||
//! @brief Setup buffers for the decoder excluding speculative decoding.
|
||||
void setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
|
||||
void setup(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
|
||||
SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType dtype,
|
||||
ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager);
|
||||
|
||||
//! @brief Setup buffers for the cache indirection.
|
||||
//! @details This is used for beam search on pipeline parallel ranks without a decoder.
|
||||
void setupCacheIndirection(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
|
||||
void setupCacheIndirection(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
|
||||
BufferManager const& bufferManager);
|
||||
|
||||
//! @brief Setup buffers for speculative decoding.
|
||||
@ -134,7 +134,7 @@ public:
|
||||
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
|
||||
[[nodiscard]] TensorPtr getAcceptedPackedPaths() const;
|
||||
|
||||
[[nodiscard]] SizeType32 getMaxBatchSize() const;
|
||||
[[nodiscard]] SizeType32 getMaxNumSequences() const;
|
||||
|
||||
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
|
||||
|
||||
@ -173,6 +173,11 @@ public:
|
||||
//! @brief Workspace for beam search in streaming mode.
|
||||
[[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const;
|
||||
|
||||
//! @brief Set the beam width for a specific request in the batch.
|
||||
//! @param batchIdx The index of the request in the batch.
|
||||
//! @param beamWidth The beam width for the specified request.
|
||||
void setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth);
|
||||
|
||||
//! @brief Cache indirection input for beam search.
|
||||
[[nodiscard]] TensorPtr getCacheIndirectionInput() const;
|
||||
|
||||
@ -187,10 +192,10 @@ public:
|
||||
//! @param generationSteps The generation steps for all requests in the batch.
|
||||
void setGenerationSteps(std::vector<SizeType32> const& generationSteps);
|
||||
|
||||
//! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
|
||||
//! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots.
|
||||
[[nodiscard]] DecodingInput& getJointDecodingInput() const;
|
||||
|
||||
//! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots.
|
||||
//! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots.
|
||||
[[nodiscard]] DecodingOutput& getJointDecodingOutput() const;
|
||||
|
||||
private:
|
||||
@ -209,13 +214,13 @@ private:
|
||||
SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
BufferManager const& bufferManager);
|
||||
|
||||
SizeType32 mMaxBatchSize{};
|
||||
SizeType32 mMaxNumSequences{};
|
||||
SizeType32 mMaxBeamWidth{};
|
||||
SizeType32 mMaxSequenceLength{};
|
||||
|
||||
//! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
|
||||
//! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots.
|
||||
DecodingInputPtr mJointDecodingInput;
|
||||
//! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots.
|
||||
//! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots.
|
||||
DecodingOutputPtr mJointDecodingOutput;
|
||||
|
||||
//! @brief Workspace for beam search in streaming mode.
|
||||
|
||||
@ -71,7 +71,7 @@ public:
|
||||
= 0;
|
||||
|
||||
static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
|
||||
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
|
||||
size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
|
||||
BufferManager::CudaStreamPtr const& stream,
|
||||
std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule = nullptr);
|
||||
};
|
||||
@ -84,7 +84,7 @@ public:
|
||||
using CudaStreamPtr = BufferManager::CudaStreamPtr;
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
|
||||
GptDecoder(executor::DecodingMode const& mode, size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize,
|
||||
size_t vocabSizePadded, CudaStreamPtr const& stream,
|
||||
std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr);
|
||||
|
||||
@ -114,7 +114,7 @@ private:
|
||||
|
||||
SamplingConfig mSamplingConfig;
|
||||
|
||||
size_t mMaxBatchSize;
|
||||
size_t mMaxNumSequences;
|
||||
size_t mVocabSize;
|
||||
size_t mVocabSizePadded;
|
||||
|
||||
@ -122,7 +122,7 @@ private:
|
||||
};
|
||||
|
||||
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
|
||||
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
|
||||
size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
|
||||
BufferManager::CudaStreamPtr const& stream,
|
||||
std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule)
|
||||
{
|
||||
@ -130,10 +130,10 @@ inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode c
|
||||
{
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
return std::make_unique<GptDecoder<float>>(
|
||||
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
|
||||
mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return std::make_unique<GptDecoder<half>>(
|
||||
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
|
||||
mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
|
||||
default:
|
||||
TLLM_THROW("Unsupported decoder data type: %d. Use either kFLOAT or kHALF.", static_cast<int>(dtype));
|
||||
return nullptr;
|
||||
|
||||
@ -47,7 +47,7 @@ public:
|
||||
|
||||
explicit GptDecoderBatched(CudaStreamPtr stream);
|
||||
|
||||
void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
|
||||
void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
|
||||
nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override;
|
||||
|
||||
void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;
|
||||
|
||||
@ -86,7 +86,7 @@ public:
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
//! @brief Setup the decoder before calling `forward()`
|
||||
virtual void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
|
||||
virtual void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
|
||||
nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
= 0;
|
||||
|
||||
|
||||
@ -31,26 +31,16 @@ public:
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
using BufferPtr = IBuffer::SharedPtr;
|
||||
|
||||
explicit Request(TensorConstPtr ids, SizeType32 inputLen, std::optional<SizeType32> maxNewTokens = std::nullopt,
|
||||
std::optional<SizeType32> endId = std::nullopt)
|
||||
: ids{std::move(ids)}
|
||||
, inputLen(inputLen)
|
||||
, maxNewTokens{maxNewTokens}
|
||||
, endId{endId}
|
||||
explicit Request(SizeType32 inputLen)
|
||||
: inputLen(inputLen)
|
||||
{
|
||||
}
|
||||
|
||||
//! Mandatory parameters
|
||||
TensorConstPtr ids; // The input sequence of token ids, [inputSeqLen], on gpu
|
||||
SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps
|
||||
|
||||
// optional parameters
|
||||
std::optional<SizeType32> maxNewTokens; // maximum number of tokens to generate for this request
|
||||
std::optional<SizeType32> endId; // end token id
|
||||
SizeType32 generatedTokensPerEngineStep{1}; //
|
||||
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength] on gpu
|
||||
TensorPtr stopWordsList; // [2, stopWordsLength] on gpu
|
||||
|
||||
//! Optional parameters for speculative decoding
|
||||
BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu
|
||||
|
||||
@ -68,6 +68,10 @@ enum class MpiTag : int
|
||||
// LogitsThread
|
||||
kSpecDecLogitsId = 129,
|
||||
kSpecDecLogitsData = 1025,
|
||||
|
||||
// KvCacheEventManager
|
||||
kKvCacheEventSize = 1026,
|
||||
kKvCacheEvent = 1027
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::mpi
|
||||
|
||||
@ -55,7 +55,7 @@ def getSMVersion():
|
||||
ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
|
||||
@pytest.mark.parametrize('flag', [
|
||||
"-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
|
||||
"-softcapping-scale-bmm1 30", "-contiguous-q-kv"
|
||||
"-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks"
|
||||
])
|
||||
@pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
|
||||
def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
|
||||
@ -122,8 +122,8 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
|
||||
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
|
||||
shell=True,
|
||||
check=True)
|
||||
# alibi and softcapping-scale-bmm1 are mutually exclusive.
|
||||
if '-softcapping-scale-bmm1' not in flag:
|
||||
# alibi doesn't work with softcapping-scale-bmm1/use-attention-sinks.
|
||||
if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag:
|
||||
subprocess.run(
|
||||
f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}",
|
||||
shell=True,
|
||||
@ -183,6 +183,23 @@ def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
|
||||
shell=True,
|
||||
check=True)
|
||||
|
||||
# For chunked prefill, we need to enable -save-softmax (dtype: bf16, sm90, layout: paged-kv or separate-q-k-v).
|
||||
if dtype == "-bf16" and input_layout in [
|
||||
"-paged-kv", "-separate-q-k-v"
|
||||
]:
|
||||
# padding mask
|
||||
subprocess.run(
|
||||
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
|
||||
{epsilon} {input_layout} -save-softmax",
|
||||
shell=True,
|
||||
check=True)
|
||||
# causal mask
|
||||
subprocess.run(
|
||||
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
|
||||
-causal-mask {epsilon} {input_layout} -save-softmax",
|
||||
shell=True,
|
||||
check=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
|
||||
ids=["bf16", "e4m3", "e4m3-bf16"])
|
||||
|
||||
@ -1971,8 +1971,7 @@ def selected_mask_types(kspec):
|
||||
sliding_or_chunked_causal_mask = '0'
|
||||
custom_mask = '0'
|
||||
elif (kspec.head_size, kspec.head_size_v) == (192, 128):
|
||||
# MLA context phase only needs causal mask now
|
||||
padding_mask = '0'
|
||||
# MLA context phase only needs causal mask and padding mask (for chunked prefill) now
|
||||
sliding_or_chunked_causal_mask = '0'
|
||||
custom_mask = '0'
|
||||
elif (kspec.head_size, kspec.head_size_v) == (576, 512):
|
||||
@ -2311,8 +2310,7 @@ def get_api_code(specs_names):
|
||||
# whether support alibi or not.
|
||||
if kspec.warp_specialization:
|
||||
il_check += '&& params.has_alibi ' if kspec.alibi else '&& !params.has_alibi '
|
||||
if kspec.input_layout.value == InputLayout.CONTIGUOUS_Q_KV:
|
||||
il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
|
||||
il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
|
||||
# use enable_attn_logit_softcapping or not.
|
||||
il_check += '&& enable_attn_logit_softcapping ' if kspec.enable_attn_logit_softcapping else '&& !enable_attn_logit_softcapping '
|
||||
# check sage block sizes
|
||||
@ -3653,104 +3651,110 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
|
||||
# alibi and enable_attn_logit_softcapping shouldn't be used together.
|
||||
if alibi and enable_attn_logit_softcapping:
|
||||
continue
|
||||
if input_layout != InputLayout.CONTIGUOUS_Q_KV and return_softmax:
|
||||
continue
|
||||
# only specify
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
sm=sm,
|
||||
sm_mma=90,
|
||||
dtype=dtype,
|
||||
seq_len=0, # support any sequence length
|
||||
head_size=[32, 40, 48, 64],
|
||||
warps_m=4, #4x1 warpgroups
|
||||
warps_n=1,
|
||||
version=2,
|
||||
interleaved=False,
|
||||
ldgsts_q=
|
||||
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
||||
ldgsts_k=False,
|
||||
ldgsts_v=False,
|
||||
share_smem_k_v=False,
|
||||
loop_step=64,
|
||||
q_tile_buffers=1, # only used by warp specialized kernels
|
||||
has_noloop=0,
|
||||
noloop_step=64,
|
||||
kv_loop_step=256,
|
||||
kv_tile_buffers=2, # only used by warp specialized kernels
|
||||
unroll_threshold=1,
|
||||
has_scale_max=False,
|
||||
flash_attention=True,
|
||||
warp_specialization=True,
|
||||
alibi=alibi,
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
# for normal attention, we only need contiguous kv as input layout when returning softmax.
|
||||
skip_combination = return_softmax and (input_layout
|
||||
!= InputLayout.CONTIGUOUS_Q_KV)
|
||||
# for context mla, we need paged kv or separate qkv as input layout when returning softmax.
|
||||
skip_mla_combination = return_softmax and (
|
||||
input_layout != InputLayout.Q_PAGED_KV
|
||||
and input_layout != InputLayout.SEPARATE_Q_K_V)
|
||||
if not skip_combination:
|
||||
# only specify
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
sm=sm,
|
||||
sm_mma=90,
|
||||
dtype=dtype,
|
||||
seq_len=0, # support any sequence length
|
||||
head_size=[32, 40, 48, 64],
|
||||
warps_m=4, #4x1 warpgroups
|
||||
warps_n=1,
|
||||
version=2,
|
||||
interleaved=False,
|
||||
ldgsts_q=
|
||||
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
||||
ldgsts_k=False,
|
||||
ldgsts_v=False,
|
||||
share_smem_k_v=False,
|
||||
loop_step=64,
|
||||
q_tile_buffers=1, # only used by warp specialized kernels
|
||||
has_noloop=0,
|
||||
noloop_step=64,
|
||||
kv_loop_step=256,
|
||||
kv_tile_buffers=2, # only used by warp specialized kernels
|
||||
unroll_threshold=1,
|
||||
has_scale_max=False,
|
||||
flash_attention=True,
|
||||
warp_specialization=True,
|
||||
alibi=alibi,
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
sm=sm,
|
||||
sm_mma=90,
|
||||
dtype=dtype,
|
||||
seq_len=0, # support any sequence length
|
||||
head_size=[72, 80, 96, 104, 128],
|
||||
warps_m=4, #4x1 warpgroups
|
||||
warps_n=1,
|
||||
version=2,
|
||||
interleaved=False,
|
||||
ldgsts_q=
|
||||
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
||||
ldgsts_k=False,
|
||||
ldgsts_v=False,
|
||||
share_smem_k_v=False,
|
||||
loop_step=64,
|
||||
q_tile_buffers=1, # only used by warp specialized kernels
|
||||
has_noloop=0,
|
||||
noloop_step=64,
|
||||
kv_loop_step=128,
|
||||
kv_tile_buffers=2, # only used by warp specialized kernels
|
||||
unroll_threshold=1,
|
||||
has_scale_max=False,
|
||||
flash_attention=True,
|
||||
warp_specialization=True,
|
||||
alibi=alibi,
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
sm=sm,
|
||||
sm_mma=90,
|
||||
dtype=dtype,
|
||||
seq_len=0, # support any sequence length
|
||||
head_size=[72, 80, 96, 104, 128],
|
||||
warps_m=4, #4x1 warpgroups
|
||||
warps_n=1,
|
||||
version=2,
|
||||
interleaved=False,
|
||||
ldgsts_q=
|
||||
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
||||
ldgsts_k=False,
|
||||
ldgsts_v=False,
|
||||
share_smem_k_v=False,
|
||||
loop_step=64,
|
||||
q_tile_buffers=1, # only used by warp specialized kernels
|
||||
has_noloop=0,
|
||||
noloop_step=64,
|
||||
kv_loop_step=128,
|
||||
kv_tile_buffers=2, # only used by warp specialized kernels
|
||||
unroll_threshold=1,
|
||||
has_scale_max=False,
|
||||
flash_attention=True,
|
||||
warp_specialization=True,
|
||||
alibi=alibi,
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
sm=sm,
|
||||
sm_mma=90,
|
||||
dtype=dtype,
|
||||
seq_len=0, # support any sequence length
|
||||
head_size=[160, 192, 256],
|
||||
warps_m=4, #4x1 warpgroups
|
||||
warps_n=1,
|
||||
version=2,
|
||||
interleaved=False,
|
||||
ldgsts_q=
|
||||
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
||||
ldgsts_k=False,
|
||||
ldgsts_v=False,
|
||||
share_smem_k_v=False,
|
||||
loop_step=64,
|
||||
q_tile_buffers=1, # only used by warp specialized kernels
|
||||
has_noloop=0,
|
||||
noloop_step=64,
|
||||
kv_loop_step=64,
|
||||
kv_tile_buffers=2, # only used by warp specialized kernels
|
||||
unroll_threshold=1,
|
||||
has_scale_max=False,
|
||||
flash_attention=True,
|
||||
warp_specialization=True,
|
||||
alibi=alibi,
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
sm=sm,
|
||||
sm_mma=90,
|
||||
dtype=dtype,
|
||||
seq_len=0, # support any sequence length
|
||||
head_size=[160, 192, 256],
|
||||
warps_m=4, #4x1 warpgroups
|
||||
warps_n=1,
|
||||
version=2,
|
||||
interleaved=False,
|
||||
ldgsts_q=
|
||||
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
||||
ldgsts_k=False,
|
||||
ldgsts_v=False,
|
||||
share_smem_k_v=False,
|
||||
loop_step=64,
|
||||
q_tile_buffers=1, # only used by warp specialized kernels
|
||||
has_noloop=0,
|
||||
noloop_step=64,
|
||||
kv_loop_step=64,
|
||||
kv_tile_buffers=2, # only used by warp specialized kernels
|
||||
unroll_threshold=1,
|
||||
has_scale_max=False,
|
||||
flash_attention=True,
|
||||
warp_specialization=True,
|
||||
alibi=alibi,
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
'''
|
||||
smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
|
||||
+ (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
|
||||
@ -3762,38 +3766,39 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
|
||||
Then for fp16/bf16 context MLA, d remains 192 (192 * 2 = 128 * 3), and dv remains 128,
|
||||
if kv_step = 128, then smem_size = 208 KB, smem is fully utilized.
|
||||
'''
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
sm=sm,
|
||||
sm_mma=90,
|
||||
dtype=dtype,
|
||||
seq_len=0, # support any sequence length
|
||||
head_size=192,
|
||||
head_size_v=128,
|
||||
warps_m=4, #4x1 warpgroups
|
||||
warps_n=1,
|
||||
version=2,
|
||||
interleaved=False,
|
||||
ldgsts_q=
|
||||
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
||||
ldgsts_k=False,
|
||||
ldgsts_v=False,
|
||||
share_smem_k_v=False,
|
||||
loop_step=64,
|
||||
q_tile_buffers=1, # only used by warp specialized kernels
|
||||
has_noloop=0,
|
||||
noloop_step=64,
|
||||
kv_loop_step=128,
|
||||
kv_tile_buffers=2, # only used by warp specialized kernels
|
||||
unroll_threshold=1,
|
||||
has_scale_max=False,
|
||||
flash_attention=True,
|
||||
warp_specialization=True,
|
||||
alibi=alibi,
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
if not skip_mla_combination:
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
sm=sm,
|
||||
sm_mma=90,
|
||||
dtype=dtype,
|
||||
seq_len=0, # support any sequence length
|
||||
head_size=192,
|
||||
head_size_v=128,
|
||||
warps_m=4, #4x1 warpgroups
|
||||
warps_n=1,
|
||||
version=2,
|
||||
interleaved=False,
|
||||
ldgsts_q=
|
||||
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
||||
ldgsts_k=False,
|
||||
ldgsts_v=False,
|
||||
share_smem_k_v=False,
|
||||
loop_step=64,
|
||||
q_tile_buffers=1, # only used by warp specialized kernels
|
||||
has_noloop=0,
|
||||
noloop_step=64,
|
||||
kv_loop_step=128,
|
||||
kv_tile_buffers=2, # only used by warp specialized kernels
|
||||
unroll_threshold=1,
|
||||
has_scale_max=False,
|
||||
flash_attention=True,
|
||||
warp_specialization=True,
|
||||
alibi=alibi,
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
|
||||
|
||||
# Note this will be used in TRT-LLM.
|
||||
|
||||
@ -1904,8 +1904,7 @@ struct Softmax_saver
|
||||
, softmax_sum_ptr_(reinterpret_cast<char*>(params.softmax_stats_ptr))
|
||||
, softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes)
|
||||
{
|
||||
size_t softmax_max_off = sizeof(float) * params.b * params.s * params.h;
|
||||
softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr) + softmax_max_off;
|
||||
softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr);
|
||||
|
||||
int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP;
|
||||
@ -1917,9 +1916,9 @@ struct Softmax_saver
|
||||
store_softmax_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0);
|
||||
|
||||
// assume fixed seq length for the batch
|
||||
size_t const bh_offset = (binfo.sum_s * params.h + binfo.bidh) * sizeof(float);
|
||||
softmax_sum_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes;
|
||||
size_t const bh_offset = (binfo.sum_s * params.h + binfo.bidh) * sizeof(float) * 2;
|
||||
softmax_max_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes;
|
||||
softmax_sum_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes + sizeof(float);
|
||||
};
|
||||
|
||||
inline __device__ void store(int q_loop, float* p_sum, float* p_max)
|
||||
@ -1938,19 +1937,19 @@ struct Softmax_saver
|
||||
int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA;
|
||||
if (row0_ + row_offset < actual_q_len_)
|
||||
{
|
||||
fmha::stg(softmax_sum_ptr_ + row_offset * softmax_stats_stride_in_bytes_, sum0);
|
||||
fmha::stg(softmax_max_ptr_ + row_offset * softmax_stats_stride_in_bytes_, max0);
|
||||
fmha::stg(softmax_sum_ptr_ + row_offset * softmax_stats_stride_in_bytes_, sum0);
|
||||
}
|
||||
if (row0_ + row_offset + 8 < actual_q_len_)
|
||||
{
|
||||
fmha::stg(softmax_sum_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, sum1);
|
||||
fmha::stg(softmax_max_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, max1);
|
||||
fmha::stg(softmax_sum_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, sum1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ptr
|
||||
// ptr (total_token_q, h, 2) float
|
||||
char* softmax_sum_ptr_ = nullptr;
|
||||
char* softmax_max_ptr_ = nullptr;
|
||||
|
||||
|
||||
@ -465,8 +465,7 @@ struct Softmax_saver_tma
|
||||
, softmax_sum_ptr_(reinterpret_cast<char*>(params.softmax_stats_ptr))
|
||||
, softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes)
|
||||
{
|
||||
size_t softmax_max_off = sizeof(float) * params.b * params.s * params.h;
|
||||
softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr) + softmax_max_off;
|
||||
softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr);
|
||||
int warp = (threadIdx.x % 128) / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP;
|
||||
// MMA row0 index (8x4 thread layout)
|
||||
@ -474,9 +473,9 @@ struct Softmax_saver_tma
|
||||
|
||||
int sum_s = params.is_s_padded ? params.s * head_info.bidb : params.cu_q_seqlens[head_info.bidb];
|
||||
int token_id = sum_s * params.h + head_info.bidh;
|
||||
size_t const bh_offset = token_id * sizeof(float) + local_q_tile_offset_ * softmax_stats_stride_in_bytes_;
|
||||
softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_;
|
||||
size_t const bh_offset = token_id * sizeof(float) * 2 + local_q_tile_offset_ * softmax_stats_stride_in_bytes_;
|
||||
softmax_max_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_;
|
||||
softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_ + sizeof(float);
|
||||
};
|
||||
|
||||
inline __device__ void store(float* p_sum, float* p_max, float sqrt_d, int row_offset, bool valid_run)
|
||||
@ -487,7 +486,7 @@ struct Softmax_saver_tma
|
||||
int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP;
|
||||
if (lane % 4 < 2)
|
||||
{
|
||||
values = p_sum[lane % 2] == 0.f ? 1.f : 1.0f / p_sum[lane % 2];
|
||||
values = p_sum[lane % 2];
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@ -326,9 +326,6 @@ struct Compute
|
||||
uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]);
|
||||
Compute_tile_o ctile_o(0, smem_v);
|
||||
|
||||
// BMM2 epilogue
|
||||
Tile_o_epilogue tile_o_epilogue(params);
|
||||
|
||||
// Mutex between two compute groups.
|
||||
OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER);
|
||||
// Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions).
|
||||
@ -368,6 +365,9 @@ struct Compute
|
||||
sage_scale_row = head_info.bidb * params.h + head_info.bidh;
|
||||
}
|
||||
|
||||
// BMM2 epilogue
|
||||
Tile_o_epilogue tile_o_epilogue(params, head_info);
|
||||
|
||||
int q_step_idx = warpgroup_id;
|
||||
|
||||
// Compute work.
|
||||
@ -490,7 +490,7 @@ struct Compute
|
||||
if (valid_run)
|
||||
{
|
||||
// Final step's update.
|
||||
tile_o_epilogue.scale(ctile_o, p_sum);
|
||||
tile_o_epilogue.scale(ctile_o, p_max, p_sum);
|
||||
// Store o_tile to gmem.
|
||||
gmem_o.store(ctile_o.acc_);
|
||||
}
|
||||
|
||||
@ -500,7 +500,7 @@ struct DMA
|
||||
int const num_valid_kv_blocks = (actual_kv_seqlen + params.paged_kv_cache.mTokensPerBlock - 1)
|
||||
>> params.paged_kv_cache.mTokensPerBlockLog2;
|
||||
|
||||
for (int q_step_idx = 0; q_step_idx < q_steps; q_step_idx += 2)
|
||||
for (int q_step_idx = 0; q_step_idx < q_steps && actual_kv_seqlen > 0; q_step_idx += 2)
|
||||
{
|
||||
load_q(bidh, q_step_idx * STEP_Q + local_q_tile_offset, desc_q, shared->smem_q[0], cbw0);
|
||||
load_q(bidh, (q_step_idx + 1) * STEP_Q + local_q_tile_offset, desc_q, shared->smem_q[1], cbw1);
|
||||
|
||||
@ -454,7 +454,7 @@ struct Softmax_base
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
|
||||
{
|
||||
uint32_t const scale = float_to_half2(correction_[mi]);
|
||||
const uint32_t scale = float_to_half2(correction_[mi]);
|
||||
|
||||
// Assume only N has multiple MMAs (MMAS_M = 1).
|
||||
// MMAS_N > 1 when N dimension is split.
|
||||
@ -477,9 +477,15 @@ struct Softmax_base
|
||||
}
|
||||
|
||||
// BMM1 scale.
|
||||
uint32_t const scale_bmm1_;
|
||||
const uint32_t scale_bmm1_;
|
||||
// BMM1 softcapping scale.
|
||||
float const softcapping_scale_bmm1_;
|
||||
|
||||
// The sliding window size.
|
||||
int const sliding_window_size_;
|
||||
// The log2 attention chunk size.
|
||||
int const log2_chunked_attention_size_;
|
||||
|
||||
// The thread idx in the warp group.
|
||||
int tidx_;
|
||||
// The col index for the mma thread layout.
|
||||
@ -487,15 +493,10 @@ struct Softmax_base
|
||||
// The row index for the mma thread layout.
|
||||
int quad_row_;
|
||||
|
||||
// The sliding window size.
|
||||
int const sliding_window_size_;
|
||||
// The log2 attention chunk size.
|
||||
int const log2_chunked_attention_size_;
|
||||
|
||||
// The packed mask ptr.
|
||||
uint32_t const* packed_mask_ptr_;
|
||||
// The packed mask k-dim stride in bytes;
|
||||
int64_t const params_packed_mask_stride_in_bytes_;
|
||||
const int64_t params_packed_mask_stride_in_bytes_;
|
||||
|
||||
// Unpacked BMM1 output buffer.
|
||||
float elt_[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2];
|
||||
@ -1072,20 +1073,53 @@ struct Tile_o_epilogue_base
|
||||
// The MMA tile for the BMM2.
|
||||
using Mma_tile_o = typename Kernel_traits::Mma_tile_o;
|
||||
|
||||
template <typename Params>
|
||||
inline __device__ Tile_o_epilogue_base(Params const& params)
|
||||
// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
|
||||
enum
|
||||
{
|
||||
; // nothing to construct.
|
||||
EXP2F_OPTIMIZATION = Kernel_traits::EXP2F_OPTIMIZATION
|
||||
};
|
||||
|
||||
template <typename Params, typename Block_info>
|
||||
inline __device__ Tile_o_epilogue_base(Params const& params, Block_info& block_info)
|
||||
{
|
||||
has_attention_sink_ = params.attention_sinks != nullptr;
|
||||
head_idx_ = block_info.bidh;
|
||||
attention_sink_ = has_attention_sink_ ? params.attention_sinks[block_info.bidh] : 0.f;
|
||||
// It is only need when the exp2f optimization is enabled, so params.scale_bmm1 is always float.
|
||||
scale_bmm1_f_ = reinterpret_cast<float const&>(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1);
|
||||
};
|
||||
|
||||
// The attention sinks.
|
||||
inline __device__ void add_attention_sink(float& sum, float max)
|
||||
{
|
||||
if (has_attention_sink_)
|
||||
{
|
||||
// The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled.
|
||||
if constexpr (EXP2F_OPTIMIZATION)
|
||||
{
|
||||
sum += exp2f(attention_sink_ * M_LOG2E - max * scale_bmm1_f_);
|
||||
}
|
||||
else
|
||||
{
|
||||
sum += expf(attention_sink_ - max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale ctile_o output by 1/sum
|
||||
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
|
||||
inline __device__ void scale(
|
||||
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
|
||||
{
|
||||
// Final step's update.
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
|
||||
{
|
||||
global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi];
|
||||
// The global sum.
|
||||
float global_sum_mi = global_sum[mi];
|
||||
// Add the attention sink to the global sum.
|
||||
add_attention_sink(global_sum_mi, global_max[mi]);
|
||||
// The scale.
|
||||
float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi;
|
||||
|
||||
// Assume only N has multiple MMAs (MMAS_M = 1).
|
||||
#pragma unroll
|
||||
@ -1096,12 +1130,21 @@ struct Tile_o_epilogue_base
|
||||
{
|
||||
float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi);
|
||||
float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1);
|
||||
reg0 *= global_sum[mi];
|
||||
reg1 *= global_sum[mi];
|
||||
reg0 *= scale;
|
||||
reg1 *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Whether the attention sink is enabled.
|
||||
bool has_attention_sink_ = false;
|
||||
// The attention sink value.
|
||||
float attention_sink_ = 0.f;
|
||||
// The float scale of bmm1 outputs.
|
||||
float scale_bmm1_f_ = 1.f;
|
||||
// The head idx.
|
||||
int head_idx_ = 0;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1138,14 +1181,21 @@ struct Tile_o_epilogue<Hopper_hgmma_fp16_traits, Kernel_traits>
|
||||
using Base::Tile_o_epilogue_base;
|
||||
|
||||
// Scale ctile_o output by 1/sum
|
||||
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
|
||||
inline __device__ void scale(
|
||||
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
|
||||
{
|
||||
// Final step's update.
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
|
||||
{
|
||||
global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi];
|
||||
uint32_t const scale = float_to_half2(global_sum[mi]);
|
||||
// The global sum.
|
||||
float global_sum_mi = global_sum[mi];
|
||||
// Add the attention sink to the global sum.
|
||||
this->add_attention_sink(global_sum_mi, global_max[mi]);
|
||||
// The scale.
|
||||
float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi;
|
||||
// The scale.
|
||||
const uint32_t scale_h = float_to_half2(scale);
|
||||
|
||||
// Assume only N has multiple MMAs (MMAS_M = 1).
|
||||
#pragma unroll
|
||||
@ -1155,7 +1205,7 @@ struct Tile_o_epilogue<Hopper_hgmma_fp16_traits, Kernel_traits>
|
||||
for (int ni = 0; ni < Mma_tile_o::CORES_N; ni++)
|
||||
{
|
||||
uint32_t& reg = ctile_o.acc_[0][mma_ni].reg(ni * Mma_tile_o::CORES_M + mi);
|
||||
reg = hmul2(reg, scale);
|
||||
reg = hmul2(reg, scale_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1215,27 +1265,58 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
|
||||
// The MMA tile for the BMM2.
|
||||
using Mma_tile_o = typename Base::Mma_tile_o;
|
||||
|
||||
// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
|
||||
enum
|
||||
{
|
||||
EXP2F_OPTIMIZATION = Base::EXP2F_OPTIMIZATION
|
||||
};
|
||||
|
||||
// Ctor.
|
||||
template <typename Params>
|
||||
inline __device__ Tile_o_epilogue(Params const& params)
|
||||
: Base(params)
|
||||
template <typename Params, typename Block_info>
|
||||
inline __device__ Tile_o_epilogue(Params const& params, Block_info& block_info)
|
||||
: Base(params, block_info)
|
||||
, scale_bmm2_(*params.scale_bmm2_d)
|
||||
{
|
||||
}
|
||||
|
||||
// Add the attention sink to the global sum.
|
||||
inline __device__ void add_attention_sink(float& sum, float max)
|
||||
{
|
||||
if (this->has_attention_sink_)
|
||||
{
|
||||
// The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled.
|
||||
// Take the log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to sum.
|
||||
float quant_scale_in_log2 = log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE);
|
||||
if constexpr (EXP2F_OPTIMIZATION)
|
||||
{
|
||||
sum += exp2f(this->attention_sink_ * M_LOG2E - max * this->scale_bmm1_f_ + quant_scale_in_log2);
|
||||
}
|
||||
else
|
||||
{
|
||||
sum += expf(this->attention_sink_ - max + quant_scale_in_log2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale ctile_o output by 1/sum
|
||||
inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M])
|
||||
inline __device__ void scale(
|
||||
Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M])
|
||||
{
|
||||
// Final step's update.
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++)
|
||||
{
|
||||
// The global sum.
|
||||
float global_sum_mi = global_sum[mi];
|
||||
// Add the attention sink to the global sum.
|
||||
add_attention_sink(global_sum_mi, global_max[mi]);
|
||||
#ifdef UNIFIED_EPILOGUE_SCALE
|
||||
// Descaling factor
|
||||
float const scale_bmm2_f_ = reinterpret_cast<float&>(scale_bmm2_);
|
||||
global_sum[mi] = global_sum[mi] == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum[mi];
|
||||
// The scale.
|
||||
float scale = global_sum_mi == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum_mi;
|
||||
#else
|
||||
global_sum[mi] = global_sum[mi] == 0.f ? 1.0f : 1.0f / global_sum[mi];
|
||||
float scale = global_sum_mi == 0.f ? 1.0f : 1.0f / global_sum_mi;
|
||||
#endif
|
||||
// Assume only N has multiple MMAs (MMAS_M = 1).
|
||||
#pragma unroll
|
||||
@ -1246,8 +1327,8 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
|
||||
{
|
||||
float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi);
|
||||
float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1);
|
||||
reg0 *= global_sum[mi];
|
||||
reg1 *= global_sum[mi];
|
||||
reg0 *= scale;
|
||||
reg1 *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,33 +29,36 @@ using Kv_block_array = fmha::Kv_block_array;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
|
||||
int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n,
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i,
|
||||
float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -81,11 +84,11 @@ void run_sage_quant(unsigned int batch_size, unsigned int head_num, unsigned int
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type,
|
||||
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type,
|
||||
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
|
||||
void* qkv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, size_t const b, size_t const s, size_t const h, size_t const d, size_t const dv,
|
||||
int const runs, int const warps_m, int const warps_n, bool const has_alibi)
|
||||
void* qkv_d, void* vt_d, void* mask_d, void* attention_sinks_d, void* p_d, void* s_d, void* tmp_d, void* o_d,
|
||||
void* softmax_sum_d, void* cu_q_seqlens_d, const size_t b, const size_t s, const size_t h, const size_t d,
|
||||
const size_t dv, int const runs, int const warps_m, int const warps_n, bool const has_alibi)
|
||||
{
|
||||
|
||||
cudaStream_t stream = 0;
|
||||
@ -106,28 +109,28 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty
|
||||
// Softmax.
|
||||
if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16)
|
||||
{
|
||||
run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
|
||||
warps_n, has_alibi);
|
||||
run_softmax_fp16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_bf16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
|
||||
warps_n, has_alibi);
|
||||
run_softmax_bf16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1,
|
||||
warps_n, has_alibi);
|
||||
run_softmax_fp32(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_softmax,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax_e4m3(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h,
|
||||
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32)
|
||||
{
|
||||
run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, scale_softmax,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax_int8(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1,
|
||||
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -179,7 +182,7 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params,
|
||||
// types
|
||||
Data_type data_type, Data_type acc_type,
|
||||
// sizes
|
||||
size_t const b, size_t const s, size_t const h, size_t const d, size_t const packed_mask_stride,
|
||||
const size_t b, const size_t s, const size_t h, const size_t d, const size_t packed_mask_stride,
|
||||
// device pointers
|
||||
void* qkv_d, void* packed_mask_d, void* o_d, void* p_d, void* s_d,
|
||||
// scale factors
|
||||
@ -235,17 +238,17 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline void set_params(bert::Fused_multihead_attention_params_v2& params, Launch_params const launch_params,
|
||||
static inline void set_params(bert::Fused_multihead_attention_params_v2& params, const Launch_params launch_params,
|
||||
// types
|
||||
Data_type data_type, Data_type acc_type, Data_type output_dtype,
|
||||
// attention input layout
|
||||
Attention_input_layout input_layout,
|
||||
// sizes
|
||||
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const h_kv, size_t const d,
|
||||
size_t const dv, size_t const total, const size_t num_grouped_heads, const size_t sliding_window_size,
|
||||
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, const size_t d,
|
||||
const size_t dv, const size_t total, const size_t num_grouped_heads, const size_t sliding_window_size,
|
||||
const size_t chunked_attention_size,
|
||||
// paged kv cache block size.
|
||||
size_t const tokens_per_block,
|
||||
const size_t tokens_per_block,
|
||||
// device pointers
|
||||
void* qkv_packed_d,
|
||||
// contiguous q.
|
||||
@ -261,8 +264,10 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
// offsets for different blocks in terms of the start address.
|
||||
int32_t* paged_block_offsets,
|
||||
// mask input.
|
||||
void* packed_mask_d, void* cu_mask_rows_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d,
|
||||
void* s_d, void* softmax_stats_d, void* scale_bmm2_d,
|
||||
void* packed_mask_d, void* cu_mask_rows_d,
|
||||
// attention sinks.
|
||||
void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, void* s_d,
|
||||
void* softmax_stats_d, void* scale_bmm2_d,
|
||||
// scale factors
|
||||
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
|
||||
// flags
|
||||
@ -329,6 +334,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
// The N dimension has to be aligned.
|
||||
params.packed_mask_stride_in_bytes = (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8;
|
||||
|
||||
// Attention sinks.
|
||||
params.attention_sinks = reinterpret_cast<float*>(attention_sinks_d);
|
||||
|
||||
#if defined(STORE_P)
|
||||
params.p_ptr = p_d;
|
||||
params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type);
|
||||
@ -340,7 +348,7 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
#endif // defined(STORE_S)
|
||||
|
||||
params.softmax_stats_ptr = softmax_stats_d;
|
||||
params.softmax_stats_stride_in_bytes = get_size_in_bytes(h, DATA_TYPE_FP32);
|
||||
params.softmax_stats_stride_in_bytes = get_size_in_bytes(h * 2, DATA_TYPE_FP32);
|
||||
|
||||
// Set the dimensions.
|
||||
params.b = b;
|
||||
@ -412,13 +420,13 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, size_t const s,
|
||||
size_t const d, Attention_mask_type const attention_mask_type, Attention_input_layout const input_layout,
|
||||
static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, const size_t s,
|
||||
const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout,
|
||||
bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma,
|
||||
bool const force_non_flash_attention, bool const force_non_warp_specialization,
|
||||
bool const force_non_granular_tiling, bool const force_fp32_acc,
|
||||
// device props
|
||||
cudaDeviceProp const props)
|
||||
const cudaDeviceProp props)
|
||||
{
|
||||
|
||||
// Set launch params to choose kernels
|
||||
@ -573,6 +581,9 @@ int main(int argc, char** argv)
|
||||
// SageAttention block sizes
|
||||
int sage_block_size_q = 0, sage_block_size_k = 0, sage_block_size_v = 0;
|
||||
|
||||
// Use attention sinks (added to the denominator of softmax)
|
||||
bool use_attention_sinks = false;
|
||||
|
||||
// Read the parameters from the command-line.
|
||||
for (int ii = 1; ii < argc; ++ii)
|
||||
{
|
||||
@ -865,21 +876,28 @@ int main(int argc, char** argv)
|
||||
{
|
||||
sage_block_size_v = strtol(argv[ii], nullptr, 10);
|
||||
}
|
||||
else if (!strcmp(argv[ii], "-use-attention-sinks"))
|
||||
{
|
||||
use_attention_sinks = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (save_softmax == true)
|
||||
{
|
||||
if (input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
|
||||
bool is_MLA = (d == 192 && dv == 128);
|
||||
if (((!is_MLA) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
|
||||
|| (is_MLA && input_layout != Attention_input_layout::Q_PAGED_KV
|
||||
&& input_layout != Attention_input_layout::SEPARATE_Q_K_V))
|
||||
{
|
||||
input_layout = Attention_input_layout::CONTIGUOUS_Q_KV;
|
||||
printf(
|
||||
"Only '--contiguous-q-kv' layout supports '-save-softmax', switched to "
|
||||
"contiguous-q-kv\n");
|
||||
fprintf(stderr,
|
||||
"For normal attention, Only '--contiguous-q-kv' layout supports "
|
||||
"'-save-softmax'. For MLA only '-paged-kv' and '-separate-q-k-v' layout supports "
|
||||
"'-save-softmax'.\n");
|
||||
exit(1);
|
||||
}
|
||||
if (data_type == DATA_TYPE_E4M3)
|
||||
{
|
||||
@ -1043,11 +1061,11 @@ int main(int argc, char** argv)
|
||||
force_non_granular_tiling, force_fp32_acc, props);
|
||||
|
||||
// The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D.
|
||||
size_t const qkv_size = s * b * h * (2 * d + dv);
|
||||
const size_t qkv_size = s * b * h * (2 * d + dv);
|
||||
// Allocate on the host.
|
||||
float* qkv_h = (float*) malloc(qkv_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type);
|
||||
const size_t qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type);
|
||||
// Allocate on the device.
|
||||
void *qkv_sbh3d_d = nullptr, *qkv_bsh3d_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&qkv_sbh3d_d, qkv_size_in_bytes));
|
||||
@ -1057,7 +1075,7 @@ int main(int argc, char** argv)
|
||||
// The shape is [B, 2, S, H, D].
|
||||
const size_t kv_size = b * s * h_kv * (d + dv);
|
||||
// The size in bytes.
|
||||
size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
|
||||
const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
|
||||
// Allocate on the host.
|
||||
void* contiguous_kv_h = malloc(kv_size_in_bytes);
|
||||
// Memset the buffer.
|
||||
@ -1071,13 +1089,13 @@ int main(int argc, char** argv)
|
||||
void** kv_cache_ptrs_h = nullptr;
|
||||
void* kv_cache_pool_ptr = nullptr;
|
||||
int32_t *kv_cache_block_offsets_h, *kv_cache_block_offsets_d = nullptr;
|
||||
size_t const max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block;
|
||||
size_t const num_total_blocks = b * 2 * max_blocks_per_seq;
|
||||
const size_t max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block;
|
||||
const size_t num_total_blocks = b * 2 * max_blocks_per_seq;
|
||||
kv_cache_ptrs_h = (void**) malloc(num_total_blocks * sizeof(void*));
|
||||
kv_cache_block_offsets_h = (int32_t*) malloc(num_total_blocks * sizeof(int32_t));
|
||||
size_t const paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type);
|
||||
const size_t paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type);
|
||||
FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_block_offsets_d), num_total_blocks * sizeof(int32_t)));
|
||||
size_t const kv_cache_pool_sz
|
||||
const size_t kv_cache_pool_sz
|
||||
= get_size_in_bytes(num_total_blocks * tokens_per_block * h_kv * (d + dv) / 2, data_type);
|
||||
FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_pool_ptr), kv_cache_pool_sz));
|
||||
size_t ptr_index = 0;
|
||||
@ -1104,7 +1122,7 @@ int main(int argc, char** argv)
|
||||
|
||||
// Q will always be [B, S, H, Dh] with paged kv cache.
|
||||
void* q_d;
|
||||
size_t const q_size = s * b * h * d;
|
||||
const size_t q_size = s * b * h * d;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&q_d, get_size_in_bytes(q_size, data_type)));
|
||||
|
||||
// K has [B, S, H_kv, D] with separate kv cache.
|
||||
@ -1122,11 +1140,11 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&scale_bmm2_d, sizeof(uint32_t)));
|
||||
|
||||
// The mask for dropout or any mask patterns.
|
||||
size_t const mask_size = s * b * s;
|
||||
const size_t mask_size = s * b * s;
|
||||
// Allocate on the host.
|
||||
float* mask_h = (float*) malloc(mask_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
|
||||
const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
|
||||
// Allocate on the device.
|
||||
void* mask_d = nullptr;
|
||||
if (!skip_checks)
|
||||
@ -1158,7 +1176,7 @@ int main(int argc, char** argv)
|
||||
v1 ? 1 : 2);
|
||||
|
||||
// The number of threads per CTA.
|
||||
size_t const threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
// The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension.
|
||||
size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m);
|
||||
// The number of mmas in the N dimension.
|
||||
@ -1182,7 +1200,7 @@ int main(int argc, char** argv)
|
||||
packed_mask_size = b * mmas_m * mmas_n * threads_per_cta;
|
||||
}
|
||||
// The size in bytes.
|
||||
size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
|
||||
const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
|
||||
// Allocate on the host.
|
||||
uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes);
|
||||
// Set it to 0 (indicates that all elements are valid).
|
||||
@ -1190,23 +1208,41 @@ int main(int argc, char** argv)
|
||||
// Allocate on the device.
|
||||
void* packed_mask_d = nullptr;
|
||||
|
||||
// The size of the attention sinks.
|
||||
const size_t attention_sinks_size_in_bytes = h * sizeof(float);
|
||||
|
||||
// The attention sinks.
|
||||
void* attention_sinks_d = nullptr;
|
||||
if (use_attention_sinks)
|
||||
{
|
||||
// Allocate on the host.
|
||||
float* attention_sinks_h = (float*) malloc(attention_sinks_size_in_bytes);
|
||||
// Randomly initialize the attention sinks.
|
||||
random_init("attention_sinks", attention_sinks_h, 1, h, 1, false, 5.f, 1.f, verbose);
|
||||
// Allocate on the device.
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&attention_sinks_d, attention_sinks_size_in_bytes));
|
||||
// Copy from the host to the device.
|
||||
FMHA_CHECK_CUDA(
|
||||
cudaMemcpy(attention_sinks_d, attention_sinks_h, attention_sinks_size_in_bytes, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
// The O matrix is packed as S * B * H * D.
|
||||
size_t const o_size = s * b * h * dv;
|
||||
const size_t o_size = s * b * h * dv;
|
||||
// Allocate on the host.
|
||||
float* o_h = (float*) malloc(o_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type);
|
||||
const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* o_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes));
|
||||
|
||||
// The softmax_stats_d vector is used to store the sum/max of the softmax per token
|
||||
// The softmax_stats_d vector is used to store the max/sum of the softmax per token
|
||||
void* softmax_stats_d;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&softmax_stats_d, 2 * sizeof(float) * b * s * h));
|
||||
FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * sizeof(float) * b * s * h));
|
||||
|
||||
// The size in bytes.
|
||||
size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
|
||||
const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
|
||||
// Allocate on the device.
|
||||
void* tmp_d = nullptr;
|
||||
if (data_type != acc_type)
|
||||
@ -1216,13 +1252,13 @@ int main(int argc, char** argv)
|
||||
|
||||
// Allocate the reference on the host.
|
||||
float* o_ref_h = (float*) malloc(o_size * sizeof(float));
|
||||
float* softmax_sum_ref_h = (float*) malloc(b * s * h * sizeof(float));
|
||||
float* softmax_sum_h = (float*) malloc(b * s * h * sizeof(float));
|
||||
float* softmax_stats_ref_h = (float*) malloc(2 * b * s * h * sizeof(float));
|
||||
float* softmax_stats_h = (float*) malloc(2 * b * s * h * sizeof(float));
|
||||
|
||||
// The P matrix is stored as one big matrix of size S x B x H x S.
|
||||
size_t const p_size = s * b * h * s;
|
||||
const size_t p_size = s * b * h * s;
|
||||
// The size in bytes.
|
||||
size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
|
||||
const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
|
||||
// Allocate on the device.
|
||||
void* p_d = nullptr;
|
||||
if (!skip_checks)
|
||||
@ -1238,7 +1274,7 @@ int main(int argc, char** argv)
|
||||
#endif // defined(STORE_P)
|
||||
|
||||
// The size in bytes of the S matrix (the data type may be different from P for int8).
|
||||
size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type);
|
||||
const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* s_d = nullptr;
|
||||
if (!skip_checks)
|
||||
@ -1327,7 +1363,7 @@ int main(int argc, char** argv)
|
||||
|
||||
std::vector<uint32_t> seqlens(b, 0); // randomly draw a batch of sequence lengths >= min_s
|
||||
std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(),
|
||||
[=](uint32_t const)
|
||||
[=](const uint32_t)
|
||||
{
|
||||
if (fix_s)
|
||||
{
|
||||
@ -1415,7 +1451,7 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_packed_d, mqa_qkv_packed_size_in_bytes));
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_d, mqa_qkv_size_in_bytes));
|
||||
|
||||
size_t const o_packed_size = cu_seqlens.back() * h * dv;
|
||||
const size_t o_packed_size = cu_seqlens.back() * h * dv;
|
||||
// Allocate on the host.
|
||||
float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float));
|
||||
void* o_packed_d = nullptr;
|
||||
@ -1676,9 +1712,9 @@ int main(int argc, char** argv)
|
||||
total, num_grouped_heads, sliding_window_size, chunked_attention_size,
|
||||
// Paged kv cache.
|
||||
tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d,
|
||||
packed_mask_d, cu_mask_rows_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, softmax_stats_ptr,
|
||||
scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, use_int8_scale_max, interleaved,
|
||||
is_s_padded, has_alibi);
|
||||
packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d,
|
||||
softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
|
||||
use_int8_scale_max, interleaved, is_s_padded, has_alibi);
|
||||
|
||||
// total number of tokens is needed to set TMA desc on the host.
|
||||
launch_params.total_q_seqlen = q_seqlens[b];
|
||||
@ -1894,8 +1930,8 @@ int main(int argc, char** argv)
|
||||
ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
|
||||
qkv_sbh3d_d,
|
||||
vt_d, // WAR pass in V'
|
||||
mask_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, warps_m, warps_n,
|
||||
has_alibi);
|
||||
mask_d, attention_sinks_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs,
|
||||
warps_m, warps_n, has_alibi);
|
||||
timer.stop();
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
FMHA_CHECK_CUDA(cudaDeviceSynchronize());
|
||||
@ -1911,7 +1947,7 @@ int main(int argc, char** argv)
|
||||
|
||||
// Read the results.
|
||||
FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_ref_h, o_d, o_size, data_type));
|
||||
FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_sum_ref_h, softmax_stats_d, b * s * h, DATA_TYPE_FP32));
|
||||
FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_ref_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
|
||||
}
|
||||
|
||||
// Fill-in p/s/o with garbage data.
|
||||
@ -1997,7 +2033,7 @@ int main(int argc, char** argv)
|
||||
std::vector<float> o_ref_trans_h(o_size);
|
||||
|
||||
FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_h, o_d_view, o_view_size, output_dtype));
|
||||
FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_sum_h, softmax_stats_d, b * s * h, DATA_TYPE_FP32));
|
||||
FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
|
||||
|
||||
if (interleaved)
|
||||
{
|
||||
@ -2009,7 +2045,6 @@ int main(int argc, char** argv)
|
||||
// Extract the last s_q tokens from the output.
|
||||
extract_and_transpose_output<float>(
|
||||
o_ref_trans_h.data(), o_ref_h, seqlens, q_seqlens, s, s_q, b, h, dv, is_s_padded);
|
||||
|
||||
if (verbose)
|
||||
{
|
||||
printf("\nChecking .....: O = V * S\n");
|
||||
@ -2018,8 +2053,8 @@ int main(int argc, char** argv)
|
||||
dv, epsilon, verbose, true);
|
||||
if (save_softmax)
|
||||
{
|
||||
int errors = check_softmax_results(softmax_sum_h, softmax_sum_ref_h, b, s, h, seqlens, cu_seqlens);
|
||||
status = status | (errors > 0);
|
||||
auto errors = check_softmax_results(softmax_stats_h, softmax_stats_ref_h, b, s, h, seqlens, cu_seqlens);
|
||||
status = status | ((errors.first + errors.second) > 0);
|
||||
}
|
||||
}
|
||||
if (status != 0)
|
||||
@ -2114,8 +2149,8 @@ int main(int argc, char** argv)
|
||||
free(s_h);
|
||||
free(o_h);
|
||||
free(o_ref_h);
|
||||
free(softmax_sum_h);
|
||||
free(softmax_sum_ref_h);
|
||||
free(softmax_stats_h);
|
||||
free(softmax_stats_ref_h);
|
||||
free(contiguous_kv_h);
|
||||
free(kv_cache_ptrs_h);
|
||||
free(kv_cache_block_offsets_h);
|
||||
|
||||
@ -192,11 +192,14 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
|
||||
void* packed_mask_ptr;
|
||||
// The mask input's stride in the N (K-seq) dimension.
|
||||
int64_t packed_mask_stride_in_bytes;
|
||||
// The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max
|
||||
// The Softmax stats vector of layout [total_tokens_q, h, 2], including softmax_max and softmax_sum
|
||||
void* softmax_stats_ptr;
|
||||
// The stride between rows of softmax_stats_ptr
|
||||
// The stride between rows of softmax_stats_ptr, default: h * sizeof(float2)
|
||||
int64_t softmax_stats_stride_in_bytes;
|
||||
|
||||
// The attention sinks (per head).
|
||||
float* attention_sinks;
|
||||
|
||||
// array of length b+1 holding prefix sum of actual q sequence lengths.
|
||||
int* cu_q_seqlens;
|
||||
// array of length b+1 holding prefix sum of actual kv sequence lengths.
|
||||
|
||||
@ -87,6 +87,8 @@ struct Fused_multihead_attention_params_v2
|
||||
fmha::Kv_block_array paged_kv_cache;
|
||||
// The mask to implement drop-out.
|
||||
void* packed_mask_ptr;
|
||||
// The attention sinks (per head).
|
||||
float* attention_sinks;
|
||||
// The O matrix (output).
|
||||
void* o_ptr;
|
||||
// The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max
|
||||
|
||||
@ -1294,27 +1294,47 @@ static void print_tensor(
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static int check_softmax_results(float const* out, float const* ref, size_t b, size_t s, size_t h,
|
||||
static std::pair<int, int> check_softmax_results(float const* out, float const* ref, size_t b, size_t s, size_t h,
|
||||
std::vector<uint32_t>& seqlens, std::vector<int>& cu_seqlens)
|
||||
{
|
||||
int n_errors = 0;
|
||||
int n_errors_max = 0;
|
||||
int n_errors_sum = 0;
|
||||
|
||||
// Check the max
|
||||
for (int b_ = 0; b_ < b; ++b_)
|
||||
{
|
||||
for (int s_ = 0; s_ < seqlens[b_]; ++s_)
|
||||
{
|
||||
for (int h_ = 0; h_ < h; ++h_)
|
||||
{
|
||||
uint64_t idx = cu_seqlens[b_] * h + s_ * h + h_;
|
||||
uint64_t idx = (cu_seqlens[b_] + s_) * h * 2 + h_ * 2;
|
||||
float sum = out[idx];
|
||||
float sum_ref = ref[idx];
|
||||
if (sum_ref != 1.0f && fabsf(sum - sum_ref) / (fabsf(sum) + fabsf(sum_ref)) > 0.01)
|
||||
{
|
||||
n_errors++;
|
||||
n_errors_max++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return n_errors;
|
||||
// Check the sum
|
||||
for (int b_ = 0; b_ < b; ++b_)
|
||||
{
|
||||
for (int s_ = 0; s_ < seqlens[b_]; ++s_)
|
||||
{
|
||||
for (int h_ = 0; h_ < h; ++h_)
|
||||
{
|
||||
uint64_t idx = (cu_seqlens[b_] + s_) * h * 2 + h_ * 2 + 1;
|
||||
float sum = out[idx];
|
||||
float sum_ref = ref[idx];
|
||||
if (sum_ref != 1.0f && fabsf(sum - sum_ref) / (fabsf(sum) + fabsf(sum_ref)) > 0.01)
|
||||
{
|
||||
n_errors_sum++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return {n_errors_max, n_errors_sum};
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -23,28 +23,30 @@ using Launch_params = bert::Fused_multihead_attention_launch_params;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
|
||||
int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n,
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i,
|
||||
float softcapping_scale_bmm1, int warps_n, bool has_alibi);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -57,10 +59,10 @@ void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type,
|
||||
void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type,
|
||||
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, void* q_d, void* kv_d, void* vt_d,
|
||||
void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, void* cu_seqlens_q_d,
|
||||
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, int const runs,
|
||||
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, int const runs,
|
||||
int const warps_m, int const warps_n, bool has_alibi)
|
||||
{
|
||||
|
||||
@ -84,20 +86,22 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty
|
||||
// Softmax.
|
||||
if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16)
|
||||
{
|
||||
run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
|
||||
run_softmax_fp16(
|
||||
s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
|
||||
run_softmax_fp32(
|
||||
s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32)
|
||||
{
|
||||
run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, 0.f,
|
||||
warps_n, has_alibi);
|
||||
run_softmax_e4m3(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax,
|
||||
0.f, warps_n, has_alibi);
|
||||
}
|
||||
else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32)
|
||||
{
|
||||
run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1,
|
||||
run_softmax_int8(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1,
|
||||
scale_softmax, 0.f, warps_n, has_alibi);
|
||||
}
|
||||
else
|
||||
@ -148,8 +152,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_mhca& param
|
||||
// types
|
||||
Data_type data_type, Data_type acc_type,
|
||||
// sizes
|
||||
size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, size_t const d_padded,
|
||||
size_t const total,
|
||||
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, const size_t d_padded,
|
||||
const size_t total,
|
||||
// device pointers
|
||||
void* q_packed_d, void* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d, void* p_d,
|
||||
void* s_d,
|
||||
@ -515,17 +519,17 @@ int main(int argc, char** argv)
|
||||
launch_params.use_tma = use_tma;
|
||||
|
||||
// The Q matrix of size S_Q x B x H x D.
|
||||
size_t const q_size = s_q * b * h * d;
|
||||
const size_t q_size = s_q * b * h * d;
|
||||
// The K and V matrices are packed into one big matrix of size S_KV x B x H x 2 x D.
|
||||
size_t const kv_size = s_kv_padded * b * h * 2 * d;
|
||||
const size_t kv_size = s_kv_padded * b * h * 2 * d;
|
||||
// Allocate on the host.
|
||||
float* q_h = (float*) malloc(q_size * sizeof(float));
|
||||
// Allocate on the host.
|
||||
float* kv_h = (float*) malloc(kv_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const q_size_in_bytes = get_size_in_bytes(q_size, data_type);
|
||||
const size_t q_size_in_bytes = get_size_in_bytes(q_size, data_type);
|
||||
// The size in bytes.
|
||||
size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
|
||||
const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* q_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&q_d, q_size_in_bytes));
|
||||
@ -534,11 +538,11 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&kv_d, kv_size_in_bytes));
|
||||
|
||||
// The mask for dropout.
|
||||
size_t const mask_size = s_q * b * s_kv_padded;
|
||||
const size_t mask_size = s_q * b * s_kv_padded;
|
||||
// Allocate on the host.
|
||||
float* mask_h = (float*) malloc(mask_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
|
||||
const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8);
|
||||
// Allocate on the device.
|
||||
void* mask_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes));
|
||||
@ -554,28 +558,28 @@ int main(int argc, char** argv)
|
||||
v1 ? 1 : 2);
|
||||
|
||||
// The number of threads per CTA.
|
||||
size_t const threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
// The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension.
|
||||
size_t const mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m);
|
||||
const size_t mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m);
|
||||
// The number of mmas in the N dimension.
|
||||
size_t const mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n);
|
||||
const size_t mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n);
|
||||
// We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask).
|
||||
assert(!v1 || mmas_n <= 4);
|
||||
// The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA.
|
||||
size_t const packed_mask_size = b * mmas_m * threads_per_cta;
|
||||
const size_t packed_mask_size = b * mmas_m * threads_per_cta;
|
||||
// The size in bytes.
|
||||
size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
|
||||
const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t);
|
||||
// Allocate on the host.
|
||||
uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes);
|
||||
// Allocate on the device.
|
||||
void* packed_mask_d = nullptr;
|
||||
|
||||
// The O matrix is packed as S_Q * B * H * D.
|
||||
size_t const o_size = s_q * b * h * d;
|
||||
const size_t o_size = s_q * b * h * d;
|
||||
// Allocate on the host.
|
||||
float* o_h = (float*) malloc(o_size * sizeof(float));
|
||||
// The size in bytes.
|
||||
size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type);
|
||||
const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* o_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes));
|
||||
@ -587,7 +591,7 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaMemset(softmax_max_d, 0x00, sizeof(float) * b * s_q * h));
|
||||
|
||||
// The size in bytes.
|
||||
size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
|
||||
const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type);
|
||||
// Allocate on the device.
|
||||
void* tmp_d = nullptr;
|
||||
if (data_type != acc_type)
|
||||
@ -599,9 +603,9 @@ int main(int argc, char** argv)
|
||||
float* o_ref_h = (float*) malloc(o_size * sizeof(float));
|
||||
|
||||
// The P matrix is stored as one big matrix of size S_Q x B x H x S_KV.
|
||||
size_t const p_size = s_q * b * h * s_kv_padded;
|
||||
const size_t p_size = s_q * b * h * s_kv_padded;
|
||||
// The size in bytes.
|
||||
size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
|
||||
const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type);
|
||||
// Allocate on the device.
|
||||
void* p_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes));
|
||||
@ -614,7 +618,7 @@ int main(int argc, char** argv)
|
||||
#endif // defined(STORE_P)
|
||||
|
||||
// The size in bytes of the S matrix (the data type may be different from P for int8).
|
||||
size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type);
|
||||
const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type);
|
||||
// Allocate on the device.
|
||||
void* s_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes));
|
||||
@ -634,9 +638,9 @@ int main(int argc, char** argv)
|
||||
|
||||
// WAR fOR MISSING CUBLAS FP8 NN SUPPORT.
|
||||
// Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V.
|
||||
size_t const v_size = s_kv_padded * b * h * d;
|
||||
const size_t v_size = s_kv_padded * b * h * d;
|
||||
// The size in bytes.
|
||||
size_t const v_size_in_bytes = get_size_in_bytes(v_size, data_type);
|
||||
const size_t v_size_in_bytes = get_size_in_bytes(v_size, data_type);
|
||||
float* vt_h = (float*) malloc(v_size * sizeof(float));
|
||||
void* vt_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&vt_d, v_size_in_bytes));
|
||||
@ -676,7 +680,7 @@ int main(int argc, char** argv)
|
||||
= [min_s, fix_s, b](int s, std::vector<uint32_t>& seqlens, std::vector<int>& cu_seqlens, void** cu_seqlens_d)
|
||||
{
|
||||
std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(),
|
||||
[=](uint32_t const)
|
||||
[=](const uint32_t)
|
||||
{
|
||||
if (fix_s)
|
||||
{
|
||||
@ -728,7 +732,7 @@ int main(int argc, char** argv)
|
||||
void* kv_packed_d = nullptr;
|
||||
FMHA_CHECK_CUDA(cudaMalloc(&kv_packed_d, kv_packed_size_in_bytes));
|
||||
|
||||
size_t const o_packed_size = cu_seqlens_q.back() * h * d;
|
||||
const size_t o_packed_size = cu_seqlens_q.back() * h * d;
|
||||
// Allocate on the host.
|
||||
float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float));
|
||||
float* o_ref_packed_h = (float*) malloc(o_packed_size * sizeof(float));
|
||||
|
||||
@ -12,9 +12,10 @@
|
||||
|
||||
#include "softmax_impl.h"
|
||||
|
||||
void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
|
||||
void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi)
|
||||
{
|
||||
run_softmax<fmha::bf16_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax<fmha::bf16_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
|
||||
b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
|
||||
@ -12,9 +12,10 @@
|
||||
|
||||
#include "softmax_impl.h"
|
||||
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
|
||||
void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi)
|
||||
{
|
||||
run_softmax<uint16_t, uint16_t>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax<uint16_t, uint16_t>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b,
|
||||
h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
|
||||
@ -12,9 +12,10 @@
|
||||
|
||||
#include "softmax_impl.h"
|
||||
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi)
|
||||
void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi)
|
||||
{
|
||||
run_softmax<fmha::fp16_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f,
|
||||
softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax<fmha::fp16_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
|
||||
b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
|
||||
@ -12,10 +12,10 @@
|
||||
|
||||
#include "softmax_impl.h"
|
||||
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi)
|
||||
void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1,
|
||||
int warps_n, bool has_alibi)
|
||||
{
|
||||
run_softmax<fmha::e4m3_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f,
|
||||
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax<fmha::e4m3_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer,
|
||||
b, h, 0.f, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include <cfloat>
|
||||
#include <cstdio>
|
||||
#include <fmha/numeric_types.h>
|
||||
#include <fmha/utils.h>
|
||||
@ -33,6 +34,8 @@ struct Softmax_params
|
||||
Src_type const* src;
|
||||
// Masks.
|
||||
int8_t const* mask;
|
||||
// Attention sinks (per head).
|
||||
float const* attention_sinks;
|
||||
// Softmax sum pointer.
|
||||
float* softmax_sum;
|
||||
// ALiBi
|
||||
@ -148,7 +151,8 @@ static inline __device__ float apply_exp_(float x, float max)
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int N>
|
||||
static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32)
|
||||
static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32,
|
||||
float& max_fp32, float const attention_sink)
|
||||
{
|
||||
|
||||
// Apply the masks.
|
||||
@ -159,7 +163,6 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma
|
||||
}
|
||||
|
||||
// Compute the max inside the thread.
|
||||
float max_fp32 = -HUGE_VALF;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < N; ++ii)
|
||||
{
|
||||
@ -233,7 +236,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma
|
||||
}
|
||||
|
||||
// Normalize.
|
||||
float inv_sum_fp32 = 1.f / sum_fp32;
|
||||
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < N; ++ii)
|
||||
{
|
||||
@ -244,7 +247,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int N>
|
||||
static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32)
|
||||
static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32,
|
||||
float& max_fp32, float const attention_sink)
|
||||
{
|
||||
// Apply the masks.
|
||||
#pragma unroll
|
||||
@ -255,7 +259,6 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma
|
||||
}
|
||||
|
||||
// Compute the max inside the thread.
|
||||
float max_fp32 = -HUGE_VALF;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < N; ++ii)
|
||||
{
|
||||
@ -401,7 +404,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma
|
||||
}
|
||||
|
||||
// Normalize.
|
||||
float inv_sum_fp32 = 1.f / sum_fp32;
|
||||
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < N; ++ii)
|
||||
{
|
||||
@ -413,7 +416,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int N>
|
||||
static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32)
|
||||
static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32,
|
||||
float& max_fp32, float const attention_sink)
|
||||
{
|
||||
|
||||
// Apply the masks.
|
||||
@ -427,7 +431,6 @@ static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&ma
|
||||
}
|
||||
|
||||
// Compute the max inside the thread.
|
||||
float max_fp32 = -HUGE_VALF;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < N; ++ii)
|
||||
{
|
||||
@ -824,7 +827,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&ma
|
||||
}
|
||||
|
||||
// Normalize.
|
||||
float inv_sum_fp32 = 1.f / sum_fp32;
|
||||
float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32));
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < N; ++ii)
|
||||
{
|
||||
@ -994,13 +997,26 @@ static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params)
|
||||
}
|
||||
}
|
||||
|
||||
// The attention sink value.
|
||||
float attention_sink = -FLT_MAX;
|
||||
if (params.attention_sinks != nullptr)
|
||||
{
|
||||
attention_sink = params.attention_sinks[hi];
|
||||
}
|
||||
|
||||
// Do the reduction.
|
||||
float sum_fp32 = 0.f;
|
||||
reduce(data_fp32, mask_, params.warps_n, sum_fp32);
|
||||
float max_fp32 = -HUGE_VALF;
|
||||
reduce(data_fp32, mask_, params.warps_n, sum_fp32, max_fp32, attention_sink);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
int sum_s = params.cu_q_seqlens[bi];
|
||||
params.softmax_sum[sum_s * params.h + si * params.h + hi] = sum_fp32;
|
||||
// [B, S, H, 2] {max, sum} float
|
||||
if (hi < params.h)
|
||||
{
|
||||
params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2] = max_fp32;
|
||||
params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2 + 1] = sum_fp32;
|
||||
}
|
||||
}
|
||||
// Reconvert to half.
|
||||
DstX_type data_dst[VECs_PER_THREAD];
|
||||
@ -1025,9 +1041,9 @@ static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params)
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Dst_type, typename Src_type>
|
||||
void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum, void* cu_q_seqlens, int s_inner,
|
||||
int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, int warps_n,
|
||||
bool has_alibi)
|
||||
void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum,
|
||||
void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax,
|
||||
float softcapping_scale_bmm1, int warps_n, bool has_alibi)
|
||||
{
|
||||
|
||||
Softmax_params<Dst_type, Src_type> params;
|
||||
@ -1039,6 +1055,7 @@ void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum
|
||||
params.softmax_sum = reinterpret_cast<float*>(softmax_sum);
|
||||
params.cu_q_seqlens = reinterpret_cast<int*>(cu_q_seqlens);
|
||||
params.mask = reinterpret_cast<int8_t const*>(mask);
|
||||
params.attention_sinks = reinterpret_cast<float const*>(attention_sinks);
|
||||
params.has_alibi = has_alibi;
|
||||
|
||||
// The dimensions and precomputed values.
|
||||
|
||||
@ -12,10 +12,10 @@
|
||||
|
||||
#include "softmax_impl.h"
|
||||
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d,
|
||||
int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1,
|
||||
int warps_n, bool has_alibi)
|
||||
void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d,
|
||||
void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax,
|
||||
float softcapping_scale_bmm1, int warps_n, bool has_alibi)
|
||||
{
|
||||
run_softmax<int8_t, int32_t>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, scale_bmm1,
|
||||
scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
run_softmax<int8_t, int32_t>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h,
|
||||
scale_bmm1, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi);
|
||||
}
|
||||
|
||||
@ -1379,6 +1379,19 @@ __device__ inline ThrdRegRowMax mergeRowMax(
|
||||
return mergedRowMax;
|
||||
}
|
||||
|
||||
__device__ inline void addAttentionSinks(
|
||||
ThrdRegRowMax& globalRowSum, ThrdRegRowMax const globalRowMax, float const* attentionSinks)
|
||||
{
|
||||
for (uint32_t i = 0; i < globalRowSum.size; i++)
|
||||
{
|
||||
uint32_t srcOffset = warp_size * i + laneId();
|
||||
if (srcOffset < headGrpSize)
|
||||
{
|
||||
globalRowSum[i] += expf(attentionSinks[srcOffset] - globalRowMax[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef NDEBUG
|
||||
__device__ __forceinline__
|
||||
#else
|
||||
@ -1405,6 +1418,7 @@ CUBIN_EXPORT __global__
|
||||
#if SPEC_DEC
|
||||
MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32)].
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
#ifdef NDEBUG
|
||||
KVCacheList<usePagedKVCache> const& cacheList,
|
||||
#if BEAM_WIDTH > 1
|
||||
@ -2371,6 +2385,12 @@ CUBIN_EXPORT __global__
|
||||
float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F);
|
||||
if (seqIterInit < nbSeqIters)
|
||||
{ // otherwise rcpRowSum will be NAN.
|
||||
// The attention sinks are moved to the multi-block reduction part if the multi-block is enabled.
|
||||
if (!isMultiBlock && attentionSinks != nullptr)
|
||||
{
|
||||
// Attention sinks are per head.
|
||||
addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp);
|
||||
}
|
||||
ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum);
|
||||
#if LOW_PREC_OUTPUT
|
||||
voScale *= rcpOutScale[0];
|
||||
@ -2559,6 +2579,11 @@ CUBIN_EXPORT __global__
|
||||
assert(std::isfinite(mergedRowSum[0]));
|
||||
}
|
||||
}
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
// Attention sinks are per head.
|
||||
addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp);
|
||||
}
|
||||
__syncthreads();
|
||||
rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum));
|
||||
GemmOutRegTile const mergedOutTile = toFp16(sumAcc);
|
||||
@ -2615,6 +2640,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
|
||||
MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32))] uint2 (each bit represents mask for one col
|
||||
// position).
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
KVCacheList<usePagedKVCache> const cacheList,
|
||||
#if BEAM_WIDTH > 1
|
||||
BeamSearchParams const beamSearchParams,
|
||||
@ -2640,7 +2666,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
|
||||
#if SPEC_DEC
|
||||
mask,
|
||||
#endif
|
||||
cacheList,
|
||||
attentionSinks, cacheList,
|
||||
#if BEAM_WIDTH > 1
|
||||
beamSearchParams,
|
||||
#endif
|
||||
@ -2667,6 +2693,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#else
|
||||
InputHead const* q,
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
#if USE_PAGED_KV_CACHE
|
||||
#if PAGED_KV_CACHE_LAYOUT == 1
|
||||
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
|
||||
@ -2760,7 +2787,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#if SPEC_DEC
|
||||
mask,
|
||||
#endif
|
||||
cacheList,
|
||||
attentionSinks, cacheList,
|
||||
#if BEAM_WIDTH > 1
|
||||
beamSearchParams,
|
||||
#endif
|
||||
@ -2788,7 +2815,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#if SPEC_DEC
|
||||
mask,
|
||||
#endif
|
||||
cacheList,
|
||||
attentionSinks, cacheList,
|
||||
#if BEAM_WIDTH > 1
|
||||
beamSearchParams,
|
||||
#endif
|
||||
|
||||
@ -101,6 +101,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
|
||||
#else
|
||||
InputHead const* q,
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
#if USE_PAGED_KV_CACHE
|
||||
#if PAGED_KV_CACHE_LAYOUT == 1
|
||||
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
|
||||
@ -140,6 +141,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#else
|
||||
InputHead const* q,
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
#if USE_PAGED_KV_CACHE
|
||||
#if PAGED_KV_CACHE_LAYOUT == 1
|
||||
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
|
||||
|
||||
@ -428,6 +428,7 @@ __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
|
||||
__device__ void storeGemm0AccToShm(
|
||||
uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc);
|
||||
__device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec);
|
||||
__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound);
|
||||
#else
|
||||
__device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
|
||||
__device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd);
|
||||
@ -453,7 +454,8 @@ __device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec
|
||||
template <bool dstIsStrided = false, typename DstHead>
|
||||
__device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst,
|
||||
SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
|
||||
ShmQWiseVec const& accColSum, uint32_t nbKHeads = 0 /* only for final result in spec dec. */);
|
||||
ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec,
|
||||
uint32_t nbKHeads = 0 /* only for final result in spec dec. */);
|
||||
#else
|
||||
__device__ void transposeVTile(
|
||||
uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src);
|
||||
@ -651,6 +653,7 @@ CUBIN_EXPORT __global__
|
||||
#else
|
||||
IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads],
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
KVCacheList<usePagedKVCache> const cacheList,
|
||||
#if USE_BEAM_SEARCH
|
||||
BeamSearchParams const beamSearchParams,
|
||||
@ -1252,7 +1255,7 @@ CUBIN_EXPORT __global__
|
||||
IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast<IOHead>();
|
||||
#if SWAP_AB
|
||||
finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
|
||||
smem.gemm1WarpGrpBar, smem.gemm1AccColSum);
|
||||
smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, nullptr);
|
||||
#else
|
||||
finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
|
||||
smem.gemm1AccColSum, 1, ctaNbValidTokens);
|
||||
@ -1262,9 +1265,16 @@ CUBIN_EXPORT __global__
|
||||
{
|
||||
uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp);
|
||||
OutputHead* const dst = &output[outOffset];
|
||||
ShmQWiseVec const* attentionSinksVec = nullptr;
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
attentionSinksVec
|
||||
= reinterpret_cast<ShmQWiseVec const*>(attentionSinks + headGrpSize * idxHeadGrp);
|
||||
}
|
||||
#if SWAP_AB
|
||||
finalizeAndWriteOut_sync<SPEC_DEC>(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc,
|
||||
xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, nbKHeads);
|
||||
xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, attentionSinksVec,
|
||||
nbKHeads);
|
||||
#else
|
||||
finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
|
||||
smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens);
|
||||
@ -1585,6 +1595,17 @@ CUBIN_EXPORT __global__
|
||||
}
|
||||
unused(bar.consumed.arrive());
|
||||
}
|
||||
// Add the attention sinks.
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
for (uint32_t i = 0; i < headsPerWarp; i++)
|
||||
{
|
||||
uint32_t const idxHead = wid + nbMathWarps * i;
|
||||
float sink = expf(
|
||||
attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - states[i].max);
|
||||
states[i].sum += sink;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp);
|
||||
auto const dst = &output[outOffset];
|
||||
@ -2029,6 +2050,22 @@ __device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smem
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound)
|
||||
{
|
||||
RegColWiseVec ret;
|
||||
constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols);
|
||||
auto const idx = laneId() % nbThrdsPerInstNBase;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++)
|
||||
{
|
||||
static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols));
|
||||
ret[i] = reinterpret_cast<
|
||||
Vec<Vec<float, GmmaAccCoreMat::cols>, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>(
|
||||
gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd)
|
||||
{
|
||||
uint32_t const idxInQuad = laneId() % 4;
|
||||
@ -2878,12 +2915,19 @@ __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRa
|
||||
template <bool dstIsStrided, typename DstHead>
|
||||
__device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst,
|
||||
SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
|
||||
ShmQWiseVec const& accColSum, uint32_t nbKHeads)
|
||||
ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads)
|
||||
{
|
||||
// @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp
|
||||
// static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of
|
||||
// mufu.rcp");
|
||||
auto const regColSum = loadShmColWiseVecWithDup(accColSum);
|
||||
auto regColSum = loadShmColWiseVecWithDup(accColSum);
|
||||
if (attentionSinksVec != nullptr)
|
||||
{
|
||||
auto const regAccColMax = loadShmColWiseVecWithDup(accColMax);
|
||||
auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1);
|
||||
auto regColSinks = expf(regAttentionSinks - regAccColMax);
|
||||
regColSum = regColSum + regColSinks;
|
||||
}
|
||||
auto const regOutScale = __frcp_rn(regColSum) * xvoScale;
|
||||
rescaleAcc(acc, regOutScale);
|
||||
|
||||
@ -3175,6 +3219,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#else
|
||||
InputHead const* q,
|
||||
#endif
|
||||
float const* attentionSinks, // [headGrpSize]
|
||||
#if USE_PAGED_KV_CACHE
|
||||
#if PAGED_KV_CACHE_LAYOUT == 1
|
||||
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
|
||||
@ -3286,7 +3331,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#else
|
||||
q,
|
||||
#endif
|
||||
cacheList,
|
||||
attentionSinks, cacheList,
|
||||
#if USE_BEAM_SEARCH
|
||||
beamSearchParams,
|
||||
#endif
|
||||
@ -3322,7 +3367,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#else
|
||||
q,
|
||||
#endif
|
||||
cacheList,
|
||||
attentionSinks, cacheList,
|
||||
#if USE_BEAM_SEARCH
|
||||
beamSearchParams,
|
||||
#endif
|
||||
|
||||
@ -1859,12 +1859,13 @@ CUtensorMap makeTensorMapForQ(
|
||||
#endif // IS_MLA
|
||||
|
||||
void launchMLA(cudaDeviceProp const& prop,
|
||||
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
|
||||
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
|
||||
float qScale, OutputHead* output, InputHead const* q,
|
||||
float* attentionSinks, // [headGrpSize], not supported.
|
||||
#if USE_PAGED_KV_CACHE
|
||||
GMemCacheHead* pool, // global pool of pages
|
||||
GMemCacheHead* pool, // global pool of pages
|
||||
KVCachePageIndex const*
|
||||
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
|
||||
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
|
||||
#else
|
||||
GMemKVCacheHead* kvCacheData,
|
||||
#endif
|
||||
|
||||
@ -45,7 +45,7 @@ using Vector = Matrix<Type, Size, 1>;
|
||||
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize)
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
|
||||
{
|
||||
uint32_t const nbTiles = divUp(seqLen, tileSize);
|
||||
auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();
|
||||
@ -113,6 +113,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
|
||||
}
|
||||
rowSum += tileRowSum;
|
||||
}
|
||||
|
||||
// Add the attention sinks.
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
for (uint32_t i = 0; i < headGrpSize; i++)
|
||||
{
|
||||
rowSum[i] += expf(attentionSinks[i] - rowMax[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> out
|
||||
= gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array());
|
||||
std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); });
|
||||
@ -123,7 +133,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
|
||||
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
|
||||
refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \
|
||||
float qScale, float kvScale, float xScale, uint32_t slidingWinSize)
|
||||
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
|
||||
|
||||
INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
|
||||
INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);
|
||||
@ -143,7 +153,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
#else
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize)
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
|
||||
{
|
||||
#endif
|
||||
float const rcpXScale = 1.f / xScale;
|
||||
@ -184,7 +194,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
|
||||
Eigen::Matrix<float, headGrpSize, Eigen::Dynamic, Eigen::RowMajor> x
|
||||
= (gemm0Acc.colwise() - rowMax).array().exp().eval();
|
||||
Eigen::Vector<float, headGrpSize> const rowSum = x.rowwise().sum().eval();
|
||||
Eigen::Vector<float, headGrpSize> rowSum = x.rowwise().sum().eval();
|
||||
|
||||
std::for_each(x.data(), x.data() + x.size(), [&](float& e) { e = float(MathElem(e * rcpXScale)); });
|
||||
|
||||
@ -200,6 +210,18 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add the attention sinks.
|
||||
#if !SPEC_DEC
|
||||
if (attentionSinks != nullptr)
|
||||
{
|
||||
for (uint32_t i = 0; i < headGrpSize; i++)
|
||||
{
|
||||
rowSum[i] += expf(attentionSinks[i] - rowMax[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> out
|
||||
= gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array());
|
||||
std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); });
|
||||
@ -217,7 +239,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
|
||||
refAttention<prec, isPaged, useBeamSearch>(IOHead const* q, CacheSeq<isPaged, useBeamSearch> const& k, \
|
||||
CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale, float kvScale, float xScale, \
|
||||
uint32_t slidingWinSize)
|
||||
uint32_t slidingWinSize, float* attentionSinks)
|
||||
#endif
|
||||
INSTANTIATE_refAttention(InputElem, false, false);
|
||||
INSTANTIATE_refAttention(InputElem, false, true);
|
||||
|
||||
@ -83,7 +83,7 @@ struct CacheSeq<true, true>
|
||||
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize);
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
|
||||
|
||||
template <typename MathElem, bool isPaged, bool useBeamSearch>
|
||||
#if SPEC_DEC
|
||||
@ -93,7 +93,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti
|
||||
#else
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize);
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
|
||||
#endif
|
||||
|
||||
template <uint32_t ropeStyle>
|
||||
|
||||
@ -130,7 +130,7 @@ template <uint32_t nbKHeads>
|
||||
#endif
|
||||
#endif
|
||||
void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false,
|
||||
bool saveData = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
|
||||
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
|
||||
{
|
||||
#if IS_MLA
|
||||
if (nbKHeads != 1)
|
||||
@ -613,6 +613,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate the attention sinks (per head)
|
||||
auto attentionSinks = ManagedMemBuf<float>(nbQHeads);
|
||||
// The attention sinks ptr.
|
||||
float* attentionSinksPtr = hasAttentionSinks ? reinterpret_cast<float*>(attentionSinks.get()) : nullptr;
|
||||
// Initialize the attention sinks (use large values to detect the potential bugs).
|
||||
for (uint32_t i = 0; i < nbQHeads; i++)
|
||||
{
|
||||
// Range: [2, 5]
|
||||
attentionSinks.get()[i] = 2.f + float(i % 4);
|
||||
}
|
||||
|
||||
if (verbose)
|
||||
{
|
||||
printf("migrating data to gpu\n");
|
||||
@ -640,6 +651,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
#if BEAM_WIDTH > 1
|
||||
cacheIndir.prefetch(dev, stream);
|
||||
#endif
|
||||
attentionSinks.prefetch(dev, stream);
|
||||
};
|
||||
prefetchToDevice(device);
|
||||
checkCuda(cudaMemsetAsync(semaphores.get(), 0, 4 * nbSemaphores, stream));
|
||||
@ -720,6 +732,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
&qHeads[0][0][0],
|
||||
#endif
|
||||
#endif
|
||||
attentionSinksPtr,
|
||||
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
|
||||
cacheKHeads.get(), cacheVHeads.get(),
|
||||
#else
|
||||
@ -1028,10 +1041,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
hostMask, qSeqLen, q_len);
|
||||
#else
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refOutput;
|
||||
auto const refAttentionSinks
|
||||
= hasAttentionSinks ? attentionSinksPtr + headGrpSize * idxKHead : nullptr;
|
||||
if (useQGMMA)
|
||||
{
|
||||
refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
|
||||
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
|
||||
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize,
|
||||
refAttentionSinks);
|
||||
// refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
|
||||
// vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
|
||||
}
|
||||
@ -1039,8 +1055,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
{
|
||||
// refOutput = refFlashAttention<InputElem, 64>(&qHeads[req][b][headGrpSize * idxKHead],
|
||||
// kCacheSeq, vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale);
|
||||
refOutput = refAttention<InputElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
|
||||
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
|
||||
refOutput
|
||||
= refAttention<InputElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, vCacheSeq,
|
||||
seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks);
|
||||
}
|
||||
#endif
|
||||
if (lowPrecOutput)
|
||||
@ -1196,11 +1213,23 @@ TEST(RefCheck, llama_V2_70b)
|
||||
runTest<2>(2, 514, false, true);
|
||||
runTest<1>(1, 4096, false, true);
|
||||
#if SLIDING_WINDOW
|
||||
runTest<2>(2, 4096, false, true, false, false, ~0, 256);
|
||||
runTest<2>(2, 400, false, true, false, false, ~0U, 256);
|
||||
runTest<2>(2, 4096, false, true, false, false, false, ~0, 256);
|
||||
runTest<2>(2, 400, false, true, false, false, false, ~0U, 256);
|
||||
#endif
|
||||
runTest<8>(120, 367, false, true);
|
||||
// runTest<8>(1792, 2048, false, true);
|
||||
runTest<8>(1792, 2048, false, true);
|
||||
}
|
||||
|
||||
TEST(RefCheck, attention_sinks)
|
||||
{
|
||||
auto runAttentionSinksTest = [](uint32_t batchSize, uint32_t seqLen)
|
||||
{ runTest<8>(batchSize, seqLen, false, true, false, false, /*hasAttentionSinks*/ true); };
|
||||
|
||||
runAttentionSinksTest(2, 2);
|
||||
runAttentionSinksTest(2, 15);
|
||||
runAttentionSinksTest(2, 256);
|
||||
runAttentionSinksTest(2, 514);
|
||||
runAttentionSinksTest(1, 4096);
|
||||
}
|
||||
|
||||
TEST(Perf, tracing_long)
|
||||
@ -1264,7 +1293,7 @@ TEST(Perf, mlperf_gptj)
|
||||
#ifndef NDEBUG
|
||||
GTEST_SKIP() << "Skipping perf tests for debug build";
|
||||
#endif
|
||||
runTest<32>(396, 800 + 224, true, false, false, false, 800);
|
||||
runTest<32>(396, 800 + 224, true, false, false, false, false, 800);
|
||||
}
|
||||
|
||||
TEST(Perf, mlperf_llama)
|
||||
|
||||
@ -53,6 +53,7 @@ using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
|
||||
using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput;
|
||||
using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner;
|
||||
using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
|
||||
using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams;
|
||||
using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation;
|
||||
|
||||
static BufferManager::CudaStreamPtr streamPtr;
|
||||
@ -980,11 +981,11 @@ public:
|
||||
auto stream = streamPtr->get();
|
||||
MoeMinLatencyParams min_latency_params;
|
||||
#ifdef USING_OSS_CUTLASS_MOE_GEMM
|
||||
mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr,
|
||||
mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true,
|
||||
mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
|
||||
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
|
||||
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
|
||||
mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
|
||||
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
|
||||
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
|
||||
mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
|
||||
mFinalOutput + mFinalOutputSize * mBufferIndex,
|
||||
@ -992,11 +993,11 @@ public:
|
||||
/*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
|
||||
/*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
|
||||
#else
|
||||
mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr,
|
||||
mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true,
|
||||
mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
|
||||
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
|
||||
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
|
||||
mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
|
||||
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
|
||||
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
|
||||
mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
|
||||
mFinalOutput + mFinalOutputSize * mBufferIndex,
|
||||
|
||||
@ -30,6 +30,11 @@ void tensorrt_llm::batch_manager::AssignReqSeqSlots::operator()(SequenceSlotMana
|
||||
{
|
||||
for (auto const& llmReq : requests)
|
||||
{
|
||||
if (llmReq->isDisaggGenerationInitState())
|
||||
{
|
||||
// Skip assigning sequence slot for DISAGG_GENERATION_INIT request
|
||||
continue;
|
||||
}
|
||||
auto const isReqNew = (llmReq->isContextInitState() && !llmReq->mSeqSlot)
|
||||
|| (llmReq->isDisaggGenerationTransmissionComplete());
|
||||
if (isReqNew && llmReq->getReturnPerfMetrics())
|
||||
|
||||
@ -360,7 +360,7 @@ void CacheFormatter::format(TransferSession& session)
|
||||
}
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
|
||||
session.appendMeasure(delay, cacheTransferTime, size);
|
||||
};
|
||||
|
||||
if (connections.size() > 1)
|
||||
@ -712,7 +712,7 @@ void CacheFormatter::unformat(TransferSession& session)
|
||||
}
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
|
||||
session.appendMeasure(delay, cacheTransferTime, size);
|
||||
};
|
||||
if (pickUpConnections.size() > 1)
|
||||
{
|
||||
@ -846,16 +846,23 @@ void CacheFormatter::unformat(TransferSession& session)
|
||||
}
|
||||
int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();
|
||||
int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism;
|
||||
int destPPSize = destConfig.getParallelConfig().mPipelineParallelism;
|
||||
int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size();
|
||||
|
||||
if (selfPPSize == destPPSize)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if (selfNumLayers % selfPPSize != 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers must be divisible by pipeline parallelism");
|
||||
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d",
|
||||
selfNumLayers, selfPPSize);
|
||||
return false;
|
||||
}
|
||||
int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size();
|
||||
int destPPSize = destConfig.getParallelConfig().mPipelineParallelism;
|
||||
if (destNumLayers % destPPSize != 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers must be divisible by pipeline parallelism");
|
||||
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d ",
|
||||
destNumLayers, destPPSize);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
||||
@ -76,15 +76,6 @@ public:
|
||||
|
||||
/// @brief Destructor.
|
||||
virtual ~BaseCacheFormatter() = default;
|
||||
|
||||
// TODO: better way for context/generation tagging
|
||||
void markAsSender(bool isSender)
|
||||
{
|
||||
kvCacheMeasureHelper.markAsSender(isSender);
|
||||
}
|
||||
|
||||
protected:
|
||||
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
|
||||
};
|
||||
|
||||
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
|
||||
|
||||
@ -44,14 +44,16 @@ namespace tensorrt_llm::batch_manager
|
||||
|
||||
using SizeType32 = CreateNewDecoderRequests::SizeType32;
|
||||
using TensorPtr = CreateNewDecoderRequests::TensorPtr;
|
||||
using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffers& inputBuffers,
|
||||
ITensor& sequenceLengths, SizeType32 beamWidth, runtime::BufferManager const& manager,
|
||||
runtime::CudaStream const& stream)
|
||||
ITensor& sequenceLengths, SizeType32 beamWidth, runtime::CudaStream const& stream)
|
||||
{
|
||||
auto const bufferManager = BufferManager{std::make_shared<CudaStream>(stream.get())};
|
||||
|
||||
auto const batchSize = contextRequests.size();
|
||||
auto batchSlotsView = tr::ITensor::slice(inputBuffers.setupBatchSlots, 0, batchSize);
|
||||
auto fillValuesView = tr::ITensor::slice(inputBuffers.fillValues, 0, batchSize);
|
||||
@ -79,8 +81,8 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
|
||||
auto batchSlotsDeviceView = tr::ITensor::slice(inputBuffers.setupBatchSlotsDevice, 0, batchSize);
|
||||
auto fillValuesViewDevice = tr::ITensor::slice(inputBuffers.fillValuesDevice, 0, batchSize);
|
||||
|
||||
manager.copy(*batchSlotsView, *batchSlotsDeviceView);
|
||||
manager.copy(*fillValuesView, *fillValuesViewDevice);
|
||||
bufferManager.copy(*batchSlotsView, *batchSlotsDeviceView);
|
||||
bufferManager.copy(*fillValuesView, *fillValuesViewDevice);
|
||||
tr::kernels::invokeFillBatch(sequenceLengths, *batchSlotsDeviceView, beamWidth, *fillValuesViewDevice, stream);
|
||||
}
|
||||
}
|
||||
@ -127,10 +129,10 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
|
||||
std::tuple<TensorPtr, std::vector<runtime::SamplingConfig>, std::vector<runtime::ITensor::SharedConstPtr>,
|
||||
std::vector<executor::LookaheadDecodingConfig>>
|
||||
CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
|
||||
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
|
||||
runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
|
||||
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
|
||||
SizeType32 maxSequenceLength, SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const
|
||||
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, nvinfer1::DataType logitsType,
|
||||
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream,
|
||||
CudaStream const& decoderStream, SizeType32 maxSequenceLength, SizeType32 beamWidth,
|
||||
OptionalRef<MedusaBuffers const> medusaBuffers) const
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
NVTX3_SCOPED_RANGE(CreateNewDecoderRequests);
|
||||
@ -141,13 +143,13 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
|
||||
|
||||
if (!finishedContextRequests.empty())
|
||||
{
|
||||
copySequenceLengths(finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth,
|
||||
bufferManager, runtimeStream);
|
||||
copySequenceLengths(
|
||||
finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth, runtimeStream);
|
||||
}
|
||||
|
||||
auto [lookaheadPrompt, lookaheadAlgoConfigs] = createDecoderRequests(finishedContextRequests,
|
||||
inputBuffers.inputsIds, decodingConfig, decoderState, bufferManager, logitsType, modelConfig, worldConfig,
|
||||
runtimeStream, decoderStream, maxSequenceLength, medusaBuffers);
|
||||
auto [lookaheadPrompt, lookaheadAlgoConfigs]
|
||||
= createDecoderRequests(finishedContextRequests, inputBuffers.inputsIds, decodingConfig, decoderState,
|
||||
logitsType, modelConfig, worldConfig, runtimeStream, decoderStream, maxSequenceLength, medusaBuffers);
|
||||
|
||||
auto const batchSize = finishedContextRequests.size();
|
||||
|
||||
@ -165,115 +167,122 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
|
||||
std::move(lookaheadAlgoConfigs)};
|
||||
}
|
||||
|
||||
void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request,
|
||||
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
|
||||
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
|
||||
SizeType32 maxSequenceLength)
|
||||
namespace
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(batchSlot >= 0);
|
||||
|
||||
BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())};
|
||||
|
||||
auto const batchSize = decoderState.getMaxBatchSize();
|
||||
TLLM_CHECK(0 <= batchSize && batchSlot < batchSize);
|
||||
auto const maxBeamWidth = decoderState.getMaxBeamWidth();
|
||||
auto const beamWidth = samplingConfig.beamWidth;
|
||||
TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth,
|
||||
tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.",
|
||||
beamWidth, maxBeamWidth));
|
||||
auto const& requestIds = request.ids;
|
||||
auto const inputLength = request.inputLen;
|
||||
auto const numDecodingEngineTokens = request.generatedTokensPerEngineStep;
|
||||
void initializeInputLengths(DecodingInput& dJointInput, SizeType32 batchSlot, SizeType32 inputLength,
|
||||
std::optional<SizeType32> maxNewTokensOpt, SizeType32 numDecodingEngineTokens, SizeType32 maxSequenceLength,
|
||||
BufferManager const& manager)
|
||||
{
|
||||
auto const numDecodingDraftEngineTokens = numDecodingEngineTokens - 1;
|
||||
auto const maxNewTokens
|
||||
= request.maxNewTokens.value_or(maxSequenceLength - inputLength - numDecodingDraftEngineTokens);
|
||||
auto const maxNewTokens = maxNewTokensOpt.value_or(maxSequenceLength - inputLength - numDecodingDraftEngineTokens);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(inputLength + maxNewTokens + numDecodingDraftEngineTokens <= maxSequenceLength,
|
||||
tc::fmtstr(
|
||||
"Input length (%d) + max new tokens (%d) + draft tokens (%d) must be less than max sequence length (%d).",
|
||||
inputLength, maxNewTokens, numDecodingDraftEngineTokens, maxSequenceLength));
|
||||
TLLM_CHECK(requestIds->getDataType() == TRTDataType<TokenIdType>::value);
|
||||
auto const endId = request.endId.value_or(-1);
|
||||
|
||||
// input
|
||||
auto& dJointInput = decoderState.getJointDecodingInput();
|
||||
TensorPtr const sequenceLimitLength{
|
||||
ITensor::slice(constPointerCast(dJointInput.sequenceLimitLength), batchSlot, 1)};
|
||||
runtime::kernels::invokeFill(*sequenceLimitLength, inputLength + maxNewTokens, manager.getStream());
|
||||
|
||||
dJointInput.beamWidths.at(batchSlot) = beamWidth;
|
||||
decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens);
|
||||
TensorPtr const inputLengths{ITensor::slice(constPointerCast(dJointInput.lengths), batchSlot, 1)};
|
||||
runtime::kernels::invokeFill(*inputLengths, inputLength, manager.getStream());
|
||||
}
|
||||
|
||||
void initializeRequestIds(DecodingInput& dJointInput, DecodingOutput& dJointOutput, SizeType32 batchSlot,
|
||||
SharedConstPtr const& requestIds, SizeType32 endId, SizeType32 beamWidth, SizeType32 maxSequenceLength,
|
||||
BufferManager const& manager)
|
||||
{
|
||||
TensorPtr const endIdTensorPtr{ITensor::slice(constPointerCast(dJointInput.endIds), batchSlot, 1)};
|
||||
runtime::kernels::invokeFill(*endIdTensorPtr, endId, decoderStream);
|
||||
runtime::kernels::invokeFill(*endIdTensorPtr, endId, manager.getStream());
|
||||
|
||||
// fill outputIds with endIds
|
||||
TensorPtr const outputIds = ITensor::slice(dJointOutput.ids, batchSlot, 1);
|
||||
auto outputIdsTileView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, maxSequenceLength}));
|
||||
runtime::kernels::invokeFill(*outputIdsTileView, endId, manager.getStream());
|
||||
|
||||
// copy the request ids into outputIds
|
||||
auto const requestIdsShape = requestIds->getShape();
|
||||
auto outputIdsView = ITensor::view(outputIds, requestIdsShape);
|
||||
manager.copy(*requestIds, *outputIdsView);
|
||||
}
|
||||
|
||||
void initializeBeamSearch(DecodingInput& dJointInput, DecodingOutput& dJointOutput, SizeType32 batchSlot,
|
||||
SizeType32 endId, SizeType32 beamWidth, SizeType32 maxSequenceLength, BufferManager const& manager)
|
||||
{
|
||||
TensorPtr const cumLogProbs = ITensor::slice(dJointOutput.cumLogProbs, batchSlot, 1);
|
||||
runtime::kernels::invokeFill(
|
||||
*IBuffer::slice(cumLogProbs, 1, beamWidth - 1), DecodingOutput::kNegativeInfinity, manager.getStream());
|
||||
|
||||
auto parentIds = ITensor::slice(dJointOutput.parentIds, batchSlot, 1);
|
||||
auto const outputIdsShape = ITensor::makeShape({1, beamWidth, maxSequenceLength});
|
||||
parentIds->reshape(outputIdsShape);
|
||||
manager.setZero(*parentIds);
|
||||
|
||||
auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1);
|
||||
manager.setZero(*cacheIndirectionInput);
|
||||
|
||||
auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1);
|
||||
manager.setZero(*cacheIndirectionOutput);
|
||||
|
||||
auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1);
|
||||
beamHypotheses.init(manager, endId);
|
||||
}
|
||||
|
||||
void initializeEmbeddingBias(DecodingInput& dJointInput, SizeType32 batchSlot,
|
||||
std::optional<TensorPtr> const& embeddingBias, nvinfer1::DataType logitsType,
|
||||
runtime::ModelConfig const& modelConfig, BufferManager const& manager)
|
||||
{
|
||||
TensorPtr const embeddingBiasSlice = ITensor::slice(constPointerCast(dJointInput.embeddingBias), batchSlot, 1);
|
||||
if (request.embeddingBias)
|
||||
if (embeddingBias.has_value())
|
||||
{
|
||||
TLLM_CHECK(request.embeddingBias->getShape().nbDims == 2);
|
||||
TLLM_CHECK(request.embeddingBias->getShape().d[0] == 1);
|
||||
TLLM_CHECK_WITH_INFO(request.embeddingBias->getShape().d[1] == modelConfig.getVocabSize(),
|
||||
auto embeddingBiasTensor = getEmbeddingBias(logitsType, embeddingBias.value());
|
||||
|
||||
TLLM_CHECK(embeddingBiasTensor->getShape().nbDims == 2);
|
||||
TLLM_CHECK(embeddingBiasTensor->getShape().d[0] == 1);
|
||||
TLLM_CHECK_WITH_INFO(embeddingBiasTensor->getShape().d[1] == modelConfig.getVocabSize(),
|
||||
"The embedding bias shape is not as expected. Expected last dimension to be same as vocab size: %d.",
|
||||
modelConfig.getVocabSize());
|
||||
manager.copy(*request.embeddingBias, *embeddingBiasSlice);
|
||||
manager.copy(*embeddingBiasTensor, *embeddingBiasSlice);
|
||||
}
|
||||
else
|
||||
{
|
||||
manager.setZero(*embeddingBiasSlice);
|
||||
}
|
||||
}
|
||||
|
||||
auto setupWords = [](std::vector<runtime::ITensor::SharedPtr>& jointWordsLists, TensorPtr const& requestWordsList,
|
||||
SharedConstPtr& jointWordsPtrs, SharedConstPtr& jointWordsLens, SizeType32& jointMaxWordsLen,
|
||||
SizeType32 batchSlot)
|
||||
void setupWords(std::vector<runtime::ITensor::SharedPtr>& jointWordsLists,
|
||||
std::optional<TensorPtr> const& requestWordsList, SharedConstPtr& jointWordsPtrs, SharedConstPtr& jointWordsLens,
|
||||
SizeType32& jointMaxWordsLen, SizeType32 batchSlot, BufferManager const& manager)
|
||||
{
|
||||
if (requestWordsList.has_value())
|
||||
{
|
||||
if (requestWordsList)
|
||||
{
|
||||
auto const wordsLen = requestWordsList->getShape().d[1];
|
||||
BufferRange<int32_t*>(*constPointerCast(jointWordsPtrs))[batchSlot]
|
||||
= runtime::bufferCast<TokenIdType>(*requestWordsList);
|
||||
runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = wordsLen;
|
||||
// FIXME: this is monotonically growing size
|
||||
jointMaxWordsLen = std::max(static_cast<SizeType32>(wordsLen), jointMaxWordsLen);
|
||||
// Move to GPU and remove leading bs1 dimension since this is what decoderRequest expects
|
||||
TensorPtr wordsList = manager.copyFrom(*requestWordsList.value(), MemoryType::kGPU);
|
||||
wordsList->squeeze(0);
|
||||
|
||||
// NOTE: jointWordsList is not used in gptDecoder, but required to keep <name>WordsList's
|
||||
// memory allocated
|
||||
jointWordsLists[batchSlot] = requestWordsList;
|
||||
}
|
||||
else
|
||||
{
|
||||
runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = 0;
|
||||
}
|
||||
};
|
||||
auto const wordsLen = wordsList->getShape().d[1];
|
||||
BufferRange<int32_t*>(*constPointerCast(jointWordsPtrs))[batchSlot]
|
||||
= runtime::bufferCast<TokenIdType>(*wordsList);
|
||||
runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = wordsLen;
|
||||
// FIXME: this is monotonically growing size
|
||||
jointMaxWordsLen = std::max(static_cast<SizeType32>(wordsLen), jointMaxWordsLen);
|
||||
|
||||
setupWords(dJointInput.stopWordsLists, request.stopWordsList, dJointInput.stopWordsPtrs, dJointInput.stopWordsLens,
|
||||
dJointInput.maxStopWordsLen, batchSlot);
|
||||
|
||||
setupWords(dJointInput.badWordsLists, request.badWordsList, dJointInput.badWordsPtrs, dJointInput.badWordsLens,
|
||||
dJointInput.maxBadWordsLen, batchSlot);
|
||||
|
||||
TensorPtr const sequenceLimitLength{
|
||||
ITensor::slice(constPointerCast(dJointInput.sequenceLimitLength), batchSlot, 1)};
|
||||
runtime::kernels::invokeFill(*sequenceLimitLength, inputLength + maxNewTokens, decoderStream);
|
||||
|
||||
TensorPtr const inputLengths{ITensor::slice(constPointerCast(dJointInput.lengths), batchSlot, 1)};
|
||||
runtime::kernels::invokeFill(*inputLengths, inputLength, decoderStream);
|
||||
|
||||
// output
|
||||
auto& dJointOutput = decoderState.getJointDecodingOutput();
|
||||
auto const outputIdsShape = ITensor::makeShape({1, beamWidth, maxSequenceLength});
|
||||
|
||||
auto finishedSum = ITensor::slice(dJointOutput.finishedSum, batchSlot, 1);
|
||||
manager.setZero(*finishedSum);
|
||||
|
||||
for (SizeType32 ti = 0; ti < decoderState.getMaxDecodingEngineTokens(); ++ti)
|
||||
{
|
||||
TensorPtr const newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, 1);
|
||||
newTokensStepView->squeeze(0);
|
||||
auto newTokensVec = ITensor::slice(newTokensStepView, batchSlot, 1);
|
||||
manager.setZero(*newTokensVec);
|
||||
// NOTE: jointWordsList is not used in gptDecoder, but required to keep <name>WordsList's
|
||||
// memory allocated
|
||||
jointWordsLists[batchSlot] = wordsList;
|
||||
}
|
||||
else
|
||||
{
|
||||
runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = 0;
|
||||
}
|
||||
};
|
||||
|
||||
TensorPtr const finishedStepsSlice = ITensor::slice(decoderState.getFinishReasons(), batchSlot, 1);
|
||||
manager.setZero(*finishedStepsSlice);
|
||||
void initializeLogProbs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SamplingConfig const& samplingConfig,
|
||||
BufferManager const& manager)
|
||||
{
|
||||
auto const beamWidth = samplingConfig.beamWidth;
|
||||
|
||||
// cumLogProb is mandatory for beamWidth > 1
|
||||
if ((samplingConfig.cumLogProbs.has_value() && samplingConfig.cumLogProbs->at(0)) || beamWidth > 1)
|
||||
@ -287,49 +296,32 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
|
||||
auto logProbs = ITensor::slice(dJointOutput.logProbs, batchSlot, 1);
|
||||
manager.setZero(*logProbs);
|
||||
}
|
||||
}
|
||||
|
||||
if (beamWidth > 1)
|
||||
void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeType32 maxDecodingEngineTokens,
|
||||
BufferManager const& manager)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto finishedSum = ITensor::slice(dJointOutput.finishedSum, batchSlot, 1);
|
||||
manager.setZero(*finishedSum);
|
||||
|
||||
for (SizeType32 ti = 0; ti < maxDecodingEngineTokens; ++ti)
|
||||
{
|
||||
TensorPtr const cumLogProbs = ITensor::slice(dJointOutput.cumLogProbs, batchSlot, 1);
|
||||
runtime::kernels::invokeFill(
|
||||
*IBuffer::slice(cumLogProbs, 1, beamWidth - 1), DecodingOutput::kNegativeInfinity, decoderStream);
|
||||
|
||||
auto parentIds = ITensor::slice(dJointOutput.parentIds, batchSlot, 1);
|
||||
parentIds->reshape(outputIdsShape);
|
||||
manager.setZero(*parentIds);
|
||||
|
||||
auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1);
|
||||
manager.setZero(*cacheIndirectionInput);
|
||||
|
||||
auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1);
|
||||
manager.setZero(*cacheIndirectionOutput);
|
||||
|
||||
auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1);
|
||||
beamHypotheses.init(manager, endId);
|
||||
TensorPtr const newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, 1);
|
||||
newTokensStepView->squeeze(0);
|
||||
auto newTokensVec = ITensor::slice(newTokensStepView, batchSlot, 1);
|
||||
manager.setZero(*newTokensVec);
|
||||
}
|
||||
|
||||
// Speculative execution
|
||||
if (numDecodingEngineTokens > 1 || decoderState.getSpeculativeDecodingMode().isDraftTokensExternal())
|
||||
{
|
||||
TLLM_CHECK(beamWidth == 1);
|
||||
newRequestSpeculativeDecoding(batchSlot, request, samplingConfig, modelConfig,
|
||||
decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, decoderStream,
|
||||
decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens());
|
||||
}
|
||||
|
||||
// fill outputIds with endIds
|
||||
TensorPtr const outputIds = ITensor::slice(dJointOutput.ids, batchSlot, 1);
|
||||
auto outputIdsTileView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, maxSequenceLength}));
|
||||
runtime::kernels::invokeFill(*outputIdsTileView, endId, decoderStream);
|
||||
|
||||
// copy the request ids into outputIds
|
||||
auto const requestIdsShape = requestIds->getShape();
|
||||
auto outputIdsView = ITensor::view(outputIds, requestIdsShape);
|
||||
manager.copy(*requestIds, *outputIdsView);
|
||||
TensorPtr const finishedStepsSlice = ITensor::slice(dJointOutput.finishReasons, batchSlot, 1);
|
||||
manager.setZero(*finishedStepsSlice);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx,
|
||||
runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
|
||||
runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput,
|
||||
@ -557,11 +549,12 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
|
||||
std::tuple<std::vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>>
|
||||
CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
|
||||
executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
|
||||
BufferManager const& bufferManager, nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig,
|
||||
runtime::WorldConfig const& worldConfig, runtime::CudaStream const& runtimeStream,
|
||||
runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
|
||||
nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
|
||||
runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
|
||||
OptionalRef<MedusaBuffers const> medusaBuffers) const
|
||||
{
|
||||
auto const decoderBufferManager = BufferManager{std::make_shared<CudaStream>(decoderStream.get())};
|
||||
|
||||
unsigned decoderInputSize{0};
|
||||
for (auto const& llmReq : finishedContextRequests)
|
||||
{
|
||||
@ -586,26 +579,38 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
|
||||
SizeType32 inputOffset{0};
|
||||
for (auto const& llmReq : finishedContextRequests)
|
||||
{
|
||||
auto const promptLen = llmReq->getPromptLen();
|
||||
auto const& reqTokens = llmReq->getTokens(0);
|
||||
TLLM_CHECK(reqTokens.size() == static_cast<decltype(reqTokens.size())>(promptLen));
|
||||
TensorPtr inputView = ITensor::slice(inputIds, inputOffset, promptLen);
|
||||
bufferManager.copy(reqTokens.data(), *inputView);
|
||||
|
||||
auto decoderRequest = decoder_batch::Request{inputView, promptLen, llmReq->mMaxNewTokens, llmReq->mEndId};
|
||||
|
||||
llmReq->mSamplingConfig.normalizeLogProbs = mIsNormalizeLogProbs;
|
||||
|
||||
TLLM_CHECK(llmReq->mSeqSlot.has_value());
|
||||
auto const batchSlot = llmReq->mSeqSlot.value();
|
||||
auto const batchSize = decoderState.getMaxNumSequences();
|
||||
TLLM_CHECK(0 <= batchSlot && batchSlot < batchSize);
|
||||
|
||||
auto const& samplingConfig = llmReq->mSamplingConfig;
|
||||
|
||||
auto const beamWidth = samplingConfig.beamWidth;
|
||||
auto const maxBeamWidth = decoderState.getMaxBeamWidth();
|
||||
TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth,
|
||||
tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.",
|
||||
beamWidth, maxBeamWidth));
|
||||
decoderState.setBeamWidth(batchSlot, beamWidth);
|
||||
|
||||
auto const promptLen = llmReq->getPromptLen();
|
||||
|
||||
auto decoderRequest = decoder_batch::Request{promptLen};
|
||||
|
||||
if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
|
||||
{
|
||||
if (llmReq->hasDraftTokens())
|
||||
{
|
||||
auto const& draftTokens = llmReq->getDraftTokens();
|
||||
decoderRequest.draftTokens = bufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL);
|
||||
// Copy to pinned host memory (don't care about stream of bufferManager)
|
||||
decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL);
|
||||
auto const& draftLogits = llmReq->getDraftLogits();
|
||||
if (draftLogits.has_value())
|
||||
{
|
||||
decoderRequest.draftLogits
|
||||
= retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), bufferManager);
|
||||
= retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager);
|
||||
}
|
||||
decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1;
|
||||
}
|
||||
@ -618,48 +623,77 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
|
||||
{
|
||||
decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens();
|
||||
}
|
||||
if (modelConfig.getSpeculativeDecodingMode().isMedusa())
|
||||
{
|
||||
TLLM_CHECK(medusaBuffers);
|
||||
llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs};
|
||||
// FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
|
||||
// When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
|
||||
decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1);
|
||||
decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1);
|
||||
}
|
||||
else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
|
||||
{
|
||||
lookaheadPrompt.emplace_back(ITensor::slice(decoderRequest.ids, 0, decoderRequest.inputLen));
|
||||
|
||||
auto const& lookaheadRuntimeConfig
|
||||
= llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value());
|
||||
lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig);
|
||||
}
|
||||
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
|
||||
auto& dJointInput = decoderState.getJointDecodingInput();
|
||||
|
||||
auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep;
|
||||
initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens,
|
||||
maxSequenceLength, decoderBufferManager);
|
||||
decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens);
|
||||
|
||||
initializeEmbeddingBias(
|
||||
dJointInput, batchSlot, llmReq->getEmbeddingBias(), logitsType, modelConfig, decoderBufferManager);
|
||||
|
||||
setupWords(dJointInput.badWordsLists, llmReq->getBadWordsList(), dJointInput.badWordsPtrs,
|
||||
dJointInput.badWordsLens, dJointInput.maxBadWordsLen, batchSlot, decoderBufferManager);
|
||||
|
||||
setupWords(dJointInput.stopWordsLists, llmReq->getStopWordsList(), dJointInput.stopWordsPtrs,
|
||||
dJointInput.stopWordsLens, dJointInput.maxStopWordsLen, batchSlot, decoderBufferManager);
|
||||
|
||||
auto& dJointOutput = decoderState.getJointDecodingOutput();
|
||||
|
||||
initializeOutputs(dJointOutput, batchSlot, decoderState.getMaxDecodingEngineTokens(), decoderBufferManager);
|
||||
|
||||
initializeLogProbs(dJointOutput, batchSlot, samplingConfig, decoderBufferManager);
|
||||
|
||||
auto const& reqTokens = llmReq->getTokens(0);
|
||||
TLLM_CHECK(reqTokens.size() == static_cast<decltype(reqTokens.size())>(promptLen));
|
||||
TensorPtr requestIds = ITensor::slice(inputIds, inputOffset, promptLen);
|
||||
// Copy to pinned host memory (don't care about stream of bufferManager)
|
||||
decoderBufferManager.copy(reqTokens.data(), *requestIds);
|
||||
auto const endId = llmReq->mEndId.value_or(-1);
|
||||
|
||||
initializeRequestIds(dJointInput, dJointOutput, batchSlot, requestIds, endId, beamWidth, maxSequenceLength,
|
||||
decoderBufferManager);
|
||||
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
decoderRequest.eagleConfig
|
||||
= llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig();
|
||||
}
|
||||
if (llmReq->getEmbeddingBias().has_value())
|
||||
{
|
||||
decoderRequest.embeddingBias = getEmbeddingBias(logitsType, llmReq->getEmbeddingBias().value());
|
||||
}
|
||||
if (llmReq->getBadWordsList().has_value())
|
||||
{
|
||||
// Move to GPU and remove leading bs1 dimension since this is what decoderRequest expects
|
||||
decoderRequest.badWordsList = bufferManager.copyFrom(*llmReq->getBadWordsList().value(), MemoryType::kGPU);
|
||||
decoderRequest.badWordsList->squeeze(0);
|
||||
}
|
||||
if (llmReq->getStopWordsList().has_value())
|
||||
{
|
||||
decoderRequest.stopWordsList
|
||||
= bufferManager.copyFrom(*llmReq->getStopWordsList().value(), MemoryType::kGPU);
|
||||
decoderRequest.stopWordsList->squeeze(0);
|
||||
initializeBeamSearch(
|
||||
dJointInput, dJointOutput, batchSlot, endId, beamWidth, maxSequenceLength, decoderBufferManager);
|
||||
}
|
||||
|
||||
TLLM_CHECK(llmReq->mSeqSlot.has_value());
|
||||
newRequest(llmReq->mSeqSlot.value(), decoderRequest, llmReq->mSamplingConfig, modelConfig, decoderState,
|
||||
runtimeStream, decoderStream, maxSequenceLength);
|
||||
// Speculative execution
|
||||
if (!decoderState.getSpeculativeDecodingMode().isNone())
|
||||
{
|
||||
TLLM_CHECK(beamWidth == 1);
|
||||
|
||||
if (modelConfig.getSpeculativeDecodingMode().isMedusa())
|
||||
{
|
||||
TLLM_CHECK(medusaBuffers);
|
||||
llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs};
|
||||
// FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
|
||||
// When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
|
||||
decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1);
|
||||
decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1);
|
||||
}
|
||||
else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
|
||||
{
|
||||
lookaheadPrompt.emplace_back(requestIds);
|
||||
|
||||
auto const& lookaheadRuntimeConfig
|
||||
= llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value());
|
||||
lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig);
|
||||
}
|
||||
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
|
||||
{
|
||||
decoderRequest.eagleConfig
|
||||
= llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig();
|
||||
}
|
||||
|
||||
newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig,
|
||||
decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream,
|
||||
decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens());
|
||||
}
|
||||
|
||||
decoderRequests.push_back(decoderRequest);
|
||||
|
||||
|
||||
@ -91,6 +91,43 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo)
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void TransferSession::appendMeasure(double delay, double duration, size_t size)
|
||||
{
|
||||
if (!mRecordMeasure)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
|
||||
mMeasures.emplace_back(Measure{delay, duration, bandwidth});
|
||||
}
|
||||
|
||||
void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
|
||||
{
|
||||
if (mMeasures.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
// write header if not exist
|
||||
if (outFile.tellp() == 0)
|
||||
{
|
||||
outFile << "RequestID";
|
||||
for (size_t i = 0; i < mMeasures.size(); i++)
|
||||
{
|
||||
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
|
||||
}
|
||||
outFile << '\n';
|
||||
}
|
||||
// write measures
|
||||
TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
|
||||
auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
|
||||
outFile << reqId;
|
||||
for (auto const& measure : mMeasures)
|
||||
{
|
||||
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
|
||||
}
|
||||
outFile << '\n' << std::flush;
|
||||
}
|
||||
|
||||
class DataResponder::Impl
|
||||
{
|
||||
public:
|
||||
|
||||
@ -97,15 +97,23 @@ private:
|
||||
class TransferSession
|
||||
{
|
||||
public:
|
||||
struct Measure
|
||||
{
|
||||
double delay; // from last token (ctx) or arrival time (gen), in ms
|
||||
double duration; // in ms
|
||||
double bandwidth; // in Gbps
|
||||
};
|
||||
|
||||
TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
|
||||
executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
|
||||
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr)
|
||||
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
|
||||
: mConnections(std::move(connections))
|
||||
, mDataContext(dataContext)
|
||||
, mSelfState(&selfState)
|
||||
, mOtherState(std::move(otherState))
|
||||
, mBufferManager(&bufferManager)
|
||||
, mRequest(llmRequest)
|
||||
, mRecordMeasure(recordMeasure)
|
||||
{
|
||||
TLLM_CHECK(!mConnections.empty());
|
||||
}
|
||||
@ -163,6 +171,11 @@ public:
|
||||
mRequest = &llmRequest;
|
||||
}
|
||||
|
||||
void appendMeasure(double delay, double duration, size_t size);
|
||||
|
||||
// TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
|
||||
void exportMeasure(std::ofstream& outFile, bool isContext) const;
|
||||
|
||||
private:
|
||||
std::vector<Connection const*> mConnections;
|
||||
DataContext mDataContext;
|
||||
@ -170,6 +183,8 @@ private:
|
||||
executor::DataTransceiverState mOtherState;
|
||||
runtime::BufferManager const* mBufferManager;
|
||||
LlmRequest const* mRequest;
|
||||
bool mRecordMeasure;
|
||||
std::vector<Measure> mMeasures;
|
||||
};
|
||||
|
||||
// Operators required for data transmission in specific communication protocols.
|
||||
@ -266,79 +281,4 @@ private:
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
};
|
||||
|
||||
class KvCacheMeasureHelper
|
||||
{
|
||||
public:
|
||||
struct Measure
|
||||
{
|
||||
double delay; // from last token (ctx) or arrival time (gen), in ms
|
||||
double duration; // in ms
|
||||
double bandwidth; // in Gbps
|
||||
};
|
||||
|
||||
KvCacheMeasureHelper(std::string output_path)
|
||||
: mOutputPath(std::move(output_path))
|
||||
{
|
||||
}
|
||||
|
||||
void markAsSender(bool isSender)
|
||||
{
|
||||
mIsSender = isSender;
|
||||
}
|
||||
|
||||
void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size)
|
||||
{
|
||||
auto bandwidth = size * 8 / (duration / 1000) / 1e9;
|
||||
if (mOutputPath.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth});
|
||||
}
|
||||
|
||||
~KvCacheMeasureHelper()
|
||||
{
|
||||
if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty())
|
||||
{
|
||||
TLLM_CHECK(mIsSender.has_value());
|
||||
auto rank = mpi::MpiComm::world().getRank();
|
||||
std::string outFilePath
|
||||
= mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? "send" : "recv") + ".csv";
|
||||
std::ofstream outFile(outFilePath);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath);
|
||||
|
||||
size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size();
|
||||
|
||||
outFile << "RequestID";
|
||||
for (size_t i = 0; i < numTransferMeasure; i++)
|
||||
{
|
||||
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
|
||||
}
|
||||
outFile << '\n';
|
||||
|
||||
for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure)
|
||||
{
|
||||
outFile << requestID;
|
||||
|
||||
for (auto const& measure : measures)
|
||||
{
|
||||
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
|
||||
}
|
||||
outFile << '\n';
|
||||
}
|
||||
|
||||
outFile.close();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<LlmRequest::RequestIdType, std::vector<Measure>> mRequestKVCacheTranfserMeasure;
|
||||
std::string mOutputPath;
|
||||
std::mutex mMutex;
|
||||
std::optional<bool> mIsSender;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -21,6 +21,8 @@
|
||||
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
#include <filesystem>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
@ -30,6 +32,21 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
|
||||
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
|
||||
}
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
static fs::path getTransferOutputPath(char const* tag)
|
||||
{
|
||||
auto outputPath = common::getEnvKVCacheTransferOutputPath();
|
||||
if (!outputPath.empty())
|
||||
{
|
||||
auto rank = mpi::MpiComm::world().getRank();
|
||||
auto path = fs::path(outputPath);
|
||||
fs::create_directories(path);
|
||||
return path / ("rank_" + std::to_string(rank) + "_" + tag + ".csv");
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
|
||||
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
|
||||
: mManager{manager}
|
||||
@ -39,7 +56,6 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
|
||||
{
|
||||
TLLM_CHECK(mManager);
|
||||
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
|
||||
mFormatter->markAsSender(true);
|
||||
}
|
||||
|
||||
[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
|
||||
@ -86,7 +102,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
|
||||
if (it == mRequestToSession.end())
|
||||
{
|
||||
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
|
||||
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager);
|
||||
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr,
|
||||
!common::getEnvKVCacheTransferOutputPath().empty());
|
||||
it = mRequestToSession.emplace(requestId, std::move(session)).first;
|
||||
}
|
||||
it->second.setConnection(peerIdx, connection);
|
||||
@ -125,6 +142,17 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
|
||||
auto it = mRequestToSession.find(requestId);
|
||||
TLLM_CHECK(it != mRequestToSession.end());
|
||||
std::unique_lock<std::mutex> lk(mMtxForMap);
|
||||
if (!common::getEnvKVCacheTransferOutputPath().empty())
|
||||
{
|
||||
if (!mMeasuresFile.is_open())
|
||||
{
|
||||
auto outputPath = getTransferOutputPath("send");
|
||||
mMeasuresFile.open(outputPath);
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
|
||||
}
|
||||
it->second.exportMeasure(mMeasuresFile, true);
|
||||
}
|
||||
mRequestToSession.erase(it);
|
||||
}
|
||||
|
||||
@ -137,7 +165,6 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage
|
||||
TLLM_CHECK(mManager);
|
||||
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
|
||||
TLLM_CHECK(mFormatter);
|
||||
mFormatter->markAsSender(false);
|
||||
}
|
||||
|
||||
TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
|
||||
@ -203,12 +230,24 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
|
||||
}
|
||||
auto const& resource = getReceiveCacheResource(llmRequest);
|
||||
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
|
||||
contextState, resource->mBufferManager, &llmRequest);
|
||||
contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
|
||||
}
|
||||
|
||||
void DataReceiverImpl::receiveSync(TransferSession& session)
|
||||
{
|
||||
mFormatter->unformat(session);
|
||||
if (!common::getEnvKVCacheTransferOutputPath().empty())
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mMeasuresFileMutex);
|
||||
if (!mMeasuresFile.is_open())
|
||||
{
|
||||
auto outputPath = getTransferOutputPath("recv");
|
||||
mMeasuresFile.open(outputPath);
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
|
||||
}
|
||||
session.exportMeasure(mMeasuresFile, false);
|
||||
}
|
||||
}
|
||||
|
||||
void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
|
||||
|
||||
@ -23,6 +23,8 @@
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
struct TransceiverTag
|
||||
@ -67,6 +69,7 @@ private:
|
||||
std::unique_ptr<BaseCacheFormatter> mFormatter;
|
||||
std::mutex mMtxForMap;
|
||||
runtime::BufferManager mBufferManager;
|
||||
std::ofstream mMeasuresFile;
|
||||
};
|
||||
|
||||
class DataReceiverImpl : public DataReceiver, public TransceiverTag
|
||||
@ -103,6 +106,8 @@ private:
|
||||
std::unique_ptr<BaseCacheFormatter> mFormatter;
|
||||
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
|
||||
std::mutex mProcessIoResouceMutex;
|
||||
std::ofstream mMeasuresFile;
|
||||
std::mutex mMeasuresFileMutex;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -88,8 +88,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
|
||||
continue;
|
||||
}
|
||||
auto const seqSlot = llmReq->mSeqSlot.value();
|
||||
if (llmReq->isContextInitState()
|
||||
&& llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen())
|
||||
if (llmReq->isContextInitState() && llmReq->isFirstContextChunk())
|
||||
{
|
||||
// The request is in the first context forward step (considering kv cache reuse).
|
||||
auto const& guideType = guidedDecodingParams->getGuideType();
|
||||
|
||||
@ -18,20 +18,51 @@
|
||||
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/executor/serialization.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
{
|
||||
|
||||
KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries)
|
||||
KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank,
|
||||
std::optional<SizeType32> attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs)
|
||||
: mRun{true}
|
||||
, mMaxSize{maxKVEventEntries}
|
||||
, mEventId{0}
|
||||
, mAttentionDpRank{attentionDpRank}
|
||||
, mAttentionDpSize{attentionDpSize}
|
||||
, mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
|
||||
{
|
||||
TLLM_CHECK(mMaxSize > 0);
|
||||
// mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this));
|
||||
if (mAttentionDpRank)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set");
|
||||
TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(),
|
||||
"Attention DP rank must be less than attention DP size");
|
||||
if (mAttentionDpRank.value() == 0)
|
||||
{
|
||||
// Rank 0 will gather events from all other ranks
|
||||
// Need to increase size
|
||||
mMaxSize *= mAttentionDpSize.value();
|
||||
}
|
||||
// Create a communicator to be used for event exchange
|
||||
mMpiComm = std::make_unique<tensorrt_llm::mpi::MpiComm>(COMM_SESSION.split(0, mAttentionDpRank.value()));
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
!mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set");
|
||||
}
|
||||
mWorkerThread = std::thread([this]() { this->worker(); });
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
if (mAttentionDpRank)
|
||||
{
|
||||
mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); });
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
KVCacheEventManager::~KVCacheEventManager()
|
||||
@ -40,12 +71,18 @@ KVCacheEventManager::~KVCacheEventManager()
|
||||
mPendingEmptyCV.notify_all();
|
||||
mEmptyCV.notify_all();
|
||||
mWorkerThread.join();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
if (mAttentionDpRank)
|
||||
{
|
||||
mExchangeAttentionDpThread.join();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void KVCacheEventManager::enqueueCreatedEvent(
|
||||
std::vector<SizeType32> const& numBlocksPerCacheLevel, SizeType32 windowSize)
|
||||
{
|
||||
enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize});
|
||||
enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize, mAttentionDpRank});
|
||||
}
|
||||
|
||||
void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks, SizeType32 windowSize)
|
||||
@ -68,7 +105,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
|
||||
block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
|
||||
}
|
||||
|
||||
enqueueEvent({mEventId++, data, windowSize});
|
||||
enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
|
||||
}
|
||||
|
||||
void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize)
|
||||
@ -81,13 +118,13 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32
|
||||
}
|
||||
else
|
||||
{
|
||||
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize});
|
||||
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank});
|
||||
}
|
||||
}
|
||||
|
||||
void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize)
|
||||
{
|
||||
enqueueEvent({mEventId++, data, windowSize});
|
||||
enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
|
||||
}
|
||||
|
||||
void KVCacheEventManager::enqueueEvent(tle::KVCacheEvent&& event)
|
||||
@ -120,8 +157,76 @@ void KVCacheEventManager::flush()
|
||||
mPendingEmptyCV.notify_one();
|
||||
}
|
||||
|
||||
void KVCacheEventManager::exchangeAttentionDpThread()
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
while (true)
|
||||
{
|
||||
TLLM_CHECK(mAttentionDpRank);
|
||||
|
||||
// Check if any of the ranks have been shutdown
|
||||
int32_t numFinished = 0;
|
||||
int32_t finished = mRun ? 0 : 1;
|
||||
mMpiComm->allreduce(&finished, &numFinished, 1, mpi::MpiType::kINT32, mpi::MpiOp::SUM);
|
||||
if (numFinished > 0)
|
||||
{
|
||||
TLLM_LOG_INFO("One of the rank has been shut down, exiting");
|
||||
break;
|
||||
}
|
||||
|
||||
// If we are not rank 0, send events to rank 0
|
||||
if (mAttentionDpRank.value() != 0)
|
||||
{
|
||||
std::vector<char> serializedEvents;
|
||||
uint64_t numEvents = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lck(mEventsMutex);
|
||||
serializedEvents = executor::Serialization::serialize(mEvents);
|
||||
numEvents = mEvents.size();
|
||||
mEvents.clear();
|
||||
}
|
||||
uint64_t vecSize = numEvents > 0 ? serializedEvents.size() : 0;
|
||||
mMpiComm->send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize);
|
||||
if (vecSize > 0)
|
||||
{
|
||||
mMpiComm->send(serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0,
|
||||
mpi::MpiTag::kKvCacheEvent);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK(mAttentionDpSize.has_value());
|
||||
// Loop until have received events from all ranks
|
||||
for (int rank = 1; rank < mAttentionDpSize.value(); ++rank)
|
||||
{
|
||||
uint64_t vecSize{0};
|
||||
mMpiComm->recv(&vecSize, 1, mpi::MpiType::kUINT64, rank, mpi::MpiTag::kKvCacheEventSize);
|
||||
if (vecSize > 0)
|
||||
{
|
||||
std::vector<char> serializedEvents(vecSize);
|
||||
mMpiComm->recv(
|
||||
serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, rank, mpi::MpiTag::kKvCacheEvent);
|
||||
|
||||
// Deserialize the events and add them to the local queue
|
||||
auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents);
|
||||
{
|
||||
std::lock_guard<std::mutex> lck(mEventsMutex);
|
||||
mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
|
||||
mEmptyCV.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs));
|
||||
}
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif
|
||||
}
|
||||
|
||||
void KVCacheEventManager::worker()
|
||||
{
|
||||
|
||||
while (true)
|
||||
{
|
||||
std::deque<tle::KVCacheEvent> events;
|
||||
@ -151,6 +256,8 @@ void KVCacheEventManager::worker()
|
||||
|
||||
// If there's still too many events, take from the front of the events queue.
|
||||
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
|
||||
|
||||
// Notify the empty condition variable to wake up any waiting threads
|
||||
mEmptyCV.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
@ -504,8 +504,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
|
||||
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
|
||||
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
|
||||
bool copyOnPartialReuse)
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
|
||||
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
|
||||
, mTokensPerBlock{tokensPerBlock}
|
||||
, mEventManager{std::move(eventManager)}
|
||||
@ -530,7 +529,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
|
||||
TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks...
|
||||
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
|
||||
sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
|
||||
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enableHashKey, enablePartialReuse,
|
||||
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
|
||||
copyOnPartialReuse);
|
||||
}
|
||||
|
||||
@ -573,8 +572,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
|
||||
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
|
||||
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
|
||||
bool copyOnPartialReuse)
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
|
||||
: mDataType{dtype}
|
||||
, mWindowSize{windowSize}
|
||||
, mNumPrimaryBlocks{blocksInPrimaryPool}
|
||||
@ -596,7 +594,6 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
|
||||
, mLogPrefix{tensorrt_llm::common::fmtstr("BlockManager[windowSize=%u]", mWindowSize)}
|
||||
, mReusedTokens{0.0}
|
||||
, mTotalInputTokens{0.0}
|
||||
, mEnableHashKey{enableHashKey}
|
||||
, mEnablePartialReuse{enablePartialReuse}
|
||||
, mCopyOnPartialReuse{copyOnPartialReuse}
|
||||
{
|
||||
@ -920,50 +917,6 @@ void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const
|
||||
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
|
||||
}
|
||||
|
||||
void WindowBlockManager::addBlockToHashMap(BlockPtr const& block)
|
||||
{
|
||||
if (!mEnableHashKey)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto range = mContextBlocksByHash.equal_range(block->getHash());
|
||||
for (auto it = range.first; it != range.second; ++it)
|
||||
{
|
||||
if (it->second == block)
|
||||
{
|
||||
// TODO: change to assert when reused block is added only once
|
||||
TLLM_LOG_TRACE(
|
||||
"Block %d by %zx exists", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
|
||||
return;
|
||||
}
|
||||
}
|
||||
TLLM_LOG_TRACE(
|
||||
"Add block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
|
||||
mContextBlocksByHash.emplace(block->getHash(), std::move(block));
|
||||
}
|
||||
|
||||
void WindowBlockManager::removeBlockFromHashMap(BlockPtr const& block)
|
||||
{
|
||||
if (mContextBlocksByHash.empty() || block->getBlockKey().uniqueTokens.empty())
|
||||
{
|
||||
// Hash key not enabled / Empty block
|
||||
return;
|
||||
}
|
||||
auto range = mContextBlocksByHash.equal_range(block->getHash());
|
||||
TLLM_LOG_TRACE(
|
||||
"Remove block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size());
|
||||
for (auto it = range.first; it != range.second; ++it)
|
||||
{
|
||||
if (it->second == block)
|
||||
{
|
||||
mContextBlocksByHash.erase(it);
|
||||
return;
|
||||
}
|
||||
}
|
||||
// TODO: should be unreachable
|
||||
TLLM_LOG_DEBUG("Trying to remove block %d by %zx that is not in hash map", block->getBlockId(), block->getHash());
|
||||
}
|
||||
|
||||
void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize)
|
||||
{
|
||||
mWindowBlockManagers.at(windowSize).onboardBlock(offloadBlock);
|
||||
@ -1104,7 +1057,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
|
||||
matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
|
||||
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(),
|
||||
matchingBlockId);
|
||||
addBlockToHashMap(matchingBlock);
|
||||
}
|
||||
searchRoot = nullptr; // no matching needed for following blocks
|
||||
}
|
||||
@ -1114,7 +1066,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
|
||||
mEvictionPolicy->claimBlock(
|
||||
matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
|
||||
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId);
|
||||
addBlockToHashMap(matchingBlock);
|
||||
searchRoot = matchingBlock;
|
||||
}
|
||||
onboardBlock(matchingBlock);
|
||||
@ -1145,7 +1096,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
|
||||
++blockItr;
|
||||
}
|
||||
freeBlock->setHash();
|
||||
addBlockToHashMap(freeBlock);
|
||||
++mMissedBlocks;
|
||||
}
|
||||
}
|
||||
@ -1169,7 +1119,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
|
||||
++blockItr;
|
||||
}
|
||||
freeBlock->setHash();
|
||||
addBlockToHashMap(freeBlock);
|
||||
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. Allocated non-shared block %d for bi %d",
|
||||
mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi);
|
||||
}
|
||||
@ -1369,9 +1318,7 @@ void WindowBlockManager::storeBlocks(
|
||||
if (oldHash != newHash)
|
||||
{
|
||||
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
|
||||
removeBlockFromHashMap(block);
|
||||
block->setHash(newHash);
|
||||
addBlockToHashMap(block);
|
||||
}
|
||||
searchRoot = block;
|
||||
}
|
||||
@ -1408,7 +1355,6 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp
|
||||
if (!block->hasRefs())
|
||||
{
|
||||
mEvictionPolicy->releaseBlock(block);
|
||||
removeBlockFromHashMap(block);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1473,7 +1419,6 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence)
|
||||
if (!block->hasRefs())
|
||||
{
|
||||
mEvictionPolicy->releaseBlock(block, true);
|
||||
removeBlockFromHashMap(block);
|
||||
}
|
||||
// Remove block from allocated blocks
|
||||
allocatedBlocks.pop_back();
|
||||
@ -1616,7 +1561,6 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence)
|
||||
if (!block->hasRefs())
|
||||
{
|
||||
mEvictionPolicy->releaseBlock(block);
|
||||
removeBlockFromHashMap(block);
|
||||
}
|
||||
}
|
||||
// Remove stored block ids in sequence
|
||||
@ -1654,8 +1598,7 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
|
||||
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
|
||||
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
|
||||
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
|
||||
enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, false, enablePartialReuse,
|
||||
copyOnPartialReuse)
|
||||
enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse)
|
||||
{
|
||||
}
|
||||
|
||||
@ -1682,8 +1625,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
|
||||
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
|
||||
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
|
||||
bool copyOnPartialReuse)
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
|
||||
: mMaxBeamWidth(maxBeamWidth)
|
||||
, mDataType(dtype)
|
||||
, mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
|
||||
@ -1693,10 +1635,9 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
|
||||
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
|
||||
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
|
||||
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
|
||||
enableHashKey, enablePartialReuse, copyOnPartialReuse)
|
||||
enablePartialReuse, copyOnPartialReuse)
|
||||
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
|
||||
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
|
||||
, mEnableHashKey{enableHashKey}
|
||||
{
|
||||
TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow)
|
||||
!= maxAttentionWindowVec.end());
|
||||
@ -1716,12 +1657,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
|
||||
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
|
||||
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
|
||||
bool copyOnPartialReuse)
|
||||
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
|
||||
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
|
||||
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
|
||||
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
|
||||
std::move(eventManager), enableHashKey, enablePartialReuse, copyOnPartialReuse)
|
||||
std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
|
||||
{
|
||||
}
|
||||
|
||||
@ -2085,30 +2025,6 @@ void KVCacheManager::addSequence(
|
||||
llmRequest->mRequestId);
|
||||
}
|
||||
mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize);
|
||||
if (mEnableHashKey && llmRequest.has_value() && beamWidth == 1)
|
||||
{
|
||||
constexpr SizeType32 beamIdx = 0;
|
||||
auto const& blockIds = sequence.getCacheBlockIds(windowSize).at(beamIdx);
|
||||
auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx);
|
||||
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(
|
||||
uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), true);
|
||||
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
|
||||
auto tokensPerBlock = static_cast<size_t>(getTokensPerBlock());
|
||||
for (size_t i = 0; i < blockIds.size(); i++)
|
||||
{
|
||||
auto const& block = mBlockManager.getBlockById(blockIds[i], windowSize);
|
||||
if (i < blockKeys.size())
|
||||
{
|
||||
block->setBlockKey(blockKeys[i], blockKeys[i].uniqueTokens.size() == tokensPerBlock);
|
||||
}
|
||||
else
|
||||
{
|
||||
block->setBlockKey({}, false);
|
||||
}
|
||||
block->setHash();
|
||||
mBlockManager.addBlockToHashMap(block, windowSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
cacheBlockOffsets(sequence, windowSize);
|
||||
}
|
||||
@ -2127,10 +2043,13 @@ void KVCacheManager::addSequence(
|
||||
void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest)
|
||||
{
|
||||
auto const requestId = llmRequest.mRequestId;
|
||||
auto& sequence = getSequence(requestId);
|
||||
if (mEnableBlockReuse && !sequence.isCyclic() && !llmRequest.isDummyRequest())
|
||||
if (mSequences.find(requestId) != mSequences.end())
|
||||
{
|
||||
mBlockManager.storeContextBlocks(sequence, llmRequest);
|
||||
auto& sequence = getSequence(requestId);
|
||||
if (mEnableBlockReuse && !sequence.isCyclic() && !llmRequest.isDummyRequest())
|
||||
{
|
||||
mBlockManager.storeContextBlocks(sequence, llmRequest);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -365,4 +365,10 @@ void LlmRequest::moveLoraWeightsToGpu(runtime::BufferManager const& manager)
|
||||
mLoraWeights = gpuLoraWeights;
|
||||
}
|
||||
|
||||
void LlmRequest::removeLoraTensors()
|
||||
{
|
||||
mLoraWeights.reset();
|
||||
mLoraConfig.reset();
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -236,7 +236,7 @@ void MLACacheFormatter::format(TransferSession& session)
|
||||
}
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
|
||||
session.appendMeasure(delay, cacheTransferTime, size);
|
||||
};
|
||||
|
||||
if (connections.size() > 1)
|
||||
@ -433,7 +433,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
|
||||
}
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
|
||||
session.appendMeasure(delay, cacheTransferTime, size);
|
||||
};
|
||||
|
||||
if (pickUpConnections.size() > 1)
|
||||
@ -583,6 +583,28 @@ void MLACacheFormatter::unformat(TransferSession& session)
|
||||
return false;
|
||||
}
|
||||
|
||||
int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();
|
||||
int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism;
|
||||
int destPPSize = destConfig.getParallelConfig().mPipelineParallelism;
|
||||
int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size();
|
||||
|
||||
if (selfPPSize == destPPSize)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if (selfNumLayers % selfPPSize != 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d",
|
||||
selfNumLayers, selfPPSize);
|
||||
return false;
|
||||
}
|
||||
if (destNumLayers % destPPSize != 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d ",
|
||||
destNumLayers, destPPSize);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
@ -591,10 +591,9 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRe
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
if (llmRequest->getLoraTaskId().has_value())
|
||||
{
|
||||
auto taskId = llmRequest->getLoraTaskId().value();
|
||||
try
|
||||
{
|
||||
return mHostLoraCache->determineNumPages(taskId);
|
||||
return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value());
|
||||
}
|
||||
catch (std::runtime_error& e)
|
||||
{
|
||||
@ -602,16 +601,6 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRe
|
||||
{
|
||||
return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value());
|
||||
}
|
||||
if (!llmRequest->getLoraWeights().has_value())
|
||||
{
|
||||
auto const reqId = llmRequest->mRequestId;
|
||||
std::string errMsg
|
||||
= "Request ID " + std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task "
|
||||
+ std::to_string(taskId) + " that's not found in LoRA CPU cache."
|
||||
" Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization,"
|
||||
" so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported.";
|
||||
throw PeftTaskNotCachedException(errMsg);
|
||||
}
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
@ -693,7 +693,7 @@ std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::c
|
||||
kvCacheConfig.getEventBufferMaxSize() > 0
|
||||
? std::make_unique<kv_cache_manager::KVCacheEventManager>(kvCacheConfig.getEventBufferMaxSize())
|
||||
: nullptr,
|
||||
false, kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse());
|
||||
kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse());
|
||||
|
||||
reshapeKvTensors(kvCacheManager->getOffsetTableDimensions());
|
||||
|
||||
@ -1866,9 +1866,9 @@ void TrtGptModelInflightBatching::setupDecoderStep(
|
||||
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
|
||||
|
||||
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
|
||||
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests,
|
||||
mRuntime->getBufferManager(), logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(),
|
||||
*mDecoder->getDecoderStream(), getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers);
|
||||
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType,
|
||||
inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), getMaxSequenceLen(),
|
||||
mOperatingBeamWidth, buffers.mMedusaBuffers);
|
||||
|
||||
auto const localBatchSize = batchSlots->getSize();
|
||||
if (localBatchSize > 0)
|
||||
|
||||
@ -55,6 +55,7 @@ struct FusedQKVMaskedAttentionDispatchParams
|
||||
T const* qkv_bias;
|
||||
T const* relative_attention_bias;
|
||||
bool const* attention_mask;
|
||||
float const* attention_sinks;
|
||||
float const* logn_scaling_ptr;
|
||||
int const* cache_indir;
|
||||
void* context_buf;
|
||||
@ -71,6 +72,7 @@ struct FusedQKVMaskedAttentionDispatchParams
|
||||
RotaryScalingType rotary_embedding_scale_type;
|
||||
float rotary_embedding_scale;
|
||||
float const* rotary_embedding_inv_freq_cache;
|
||||
float2 const* rotary_embedding_cos_sin_cache;
|
||||
float rotary_embedding_short_m_scale;
|
||||
float rotary_embedding_long_m_scale;
|
||||
int rotary_embedding_max_positions;
|
||||
@ -225,6 +227,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
|
||||
xqaParams.output = generationsParams.context_buf;
|
||||
xqaParams.qkv = generationsParams.attention_input;
|
||||
xqaParams.cache_indir = generationsParams.cache_indir;
|
||||
xqaParams.attention_sinks = generationsParams.attention_sinks;
|
||||
xqaParams.kv_scale_orig_quant = generationsParams.kv_scale_orig_quant;
|
||||
xqaParams.kv_scale_quant_orig = generationsParams.kv_scale_quant_orig;
|
||||
xqaParams.host_past_key_value_lengths = generationsParams.host_past_key_value_lengths;
|
||||
@ -275,7 +278,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
|
||||
xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr;
|
||||
xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens;
|
||||
xqaParams.is_fp8_output = mFP8ContextFMHA;
|
||||
xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
|
||||
xqaParams.fp8_out_scale
|
||||
= ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr);
|
||||
// Parameters required for FP4 output.
|
||||
xqaParams.output_sf = generationsParams.context_buf_sf;
|
||||
xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
|
||||
@ -596,6 +600,7 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
|
||||
params.rotary_embedding_scale_type = input_params.rotary_embedding_scale_type;
|
||||
params.rotary_embedding_scale = input_params.rotary_embedding_scale;
|
||||
params.rotary_embedding_inv_freq_cache = input_params.rotary_embedding_inv_freq_cache;
|
||||
params.rotary_embedding_cos_sin_cache = input_params.rotary_embedding_cos_sin_cache;
|
||||
params.rotary_embedding_short_m_scale = input_params.rotary_embedding_short_m_scale;
|
||||
params.rotary_embedding_long_m_scale = input_params.rotary_embedding_long_m_scale;
|
||||
params.rotary_embedding_max_positions = input_params.rotary_embedding_max_positions;
|
||||
@ -620,6 +625,9 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
|
||||
params.attention_mask = input_params.attention_mask;
|
||||
params.attention_mask_stride = input_params.attention_mask_stride;
|
||||
|
||||
// Attention sinks.
|
||||
params.attention_sinks = input_params.attention_sinks;
|
||||
|
||||
// The slope of linear position bias per head, e.g., ALiBi.
|
||||
if (input_params.linear_bias_slopes != nullptr)
|
||||
{
|
||||
@ -729,10 +737,29 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
|
||||
size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
|
||||
size_t const qk_buf_float_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
|
||||
size_t const fp8_qkv_buffer_size
|
||||
= mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
|
||||
int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
|
||||
int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
|
||||
int const dim_v_per_head = (mMLAParams.v_head_dim);
|
||||
|
||||
// Total dimension per token across all heads for Q, K, and V components respectively
|
||||
int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
|
||||
int const total_k_dim_all_heads
|
||||
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
int const total_v_dim_all_heads
|
||||
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
|
||||
int const num_total_qkv_elements
|
||||
= max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
|
||||
|
||||
size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
|
||||
? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
|
||||
: 0;
|
||||
if (mFP8ContextMLA)
|
||||
{
|
||||
fp8_qkv_buffer_size
|
||||
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
|
||||
}
|
||||
|
||||
size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
|
||||
size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
|
||||
// Each token holds (batch_idx, token_idx_in_seq) int2.
|
||||
@ -1342,10 +1369,26 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
size_t const qk_buf_float_size = mEnableContextFMHA
|
||||
? 0
|
||||
: sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
|
||||
size_t const fp8_qkv_buffer_size
|
||||
= mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
|
||||
int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
|
||||
int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
|
||||
int const dim_v_per_head = (mMLAParams.v_head_dim);
|
||||
|
||||
// Total dimension per token across all heads for Q, K, and V components respectively
|
||||
int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
|
||||
int const total_k_dim_all_heads
|
||||
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
int const total_v_dim_all_heads
|
||||
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
int const num_total_qkv_elements
|
||||
= params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
|
||||
size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
|
||||
? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
|
||||
: 0;
|
||||
if (mFP8ContextMLA)
|
||||
{
|
||||
fp8_qkv_buffer_size
|
||||
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
|
||||
}
|
||||
size_t const padding_offset_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
|
||||
size_t const encoder_padding_offset_size
|
||||
@ -1353,8 +1396,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
// Each token holds (batch_idx, token_idx_in_seq) int2.
|
||||
size_t const tokens_info_size = sizeof(int2) * params.num_tokens;
|
||||
size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
|
||||
size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
|
||||
size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
|
||||
size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0;
|
||||
size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0;
|
||||
|
||||
// cp workspace size upper bound
|
||||
size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1);
|
||||
@ -1601,6 +1644,18 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
params.mla_param->cache_type = cache_type;
|
||||
params.mla_param->cu_q_seqlens = cu_q_seqlens;
|
||||
params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
|
||||
// Set BMM scales for FP8 context computation
|
||||
params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr;
|
||||
params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr;
|
||||
params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr;
|
||||
// Set additional scales for context phase
|
||||
params.mla_param->quant_scale_o = params.attention_output_orig_quant;
|
||||
params.mla_param->quant_scale_q = params.kv_scale_orig_quant;
|
||||
params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
|
||||
params.mla_param->dequant_scale_q = params.kv_scale_quant_orig;
|
||||
params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
|
||||
params.mla_param->host_bmm1_scale
|
||||
= 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
|
||||
if (mPagedContextFMHA && mPagedKVCache)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr,
|
||||
@ -1679,8 +1734,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
// TODO: set it correctly for contiguous kv buffer (cross-attention).
|
||||
fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
|
||||
// Device buffer pointers.
|
||||
fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
|
||||
: reinterpret_cast<void const*>(attention_input);
|
||||
fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast<void const*>(fp8_qkv_buffer)
|
||||
: reinterpret_cast<void const*>(attention_input);
|
||||
fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
|
||||
// TODO: add contiguous kv buffer (cross-attention).
|
||||
fmhaParams.kvPtr = nullptr;
|
||||
@ -1691,6 +1746,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
fmhaParams.outputPtr
|
||||
= mCpSize > 1 ? gatherOutBuffer : params.context_buf; // only use [totalLength, h / cpSize, Dh]
|
||||
fmhaParams.outputSfPtr = params.context_buf_sf;
|
||||
fmhaParams.attentionSinksPtr = params.attention_sinks;
|
||||
fmhaParams.packedMaskPtr = params.attention_packed_mask;
|
||||
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
|
||||
{
|
||||
@ -2220,6 +2276,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride;
|
||||
dispatch_params.attention_mask = params.attention_mask;
|
||||
dispatch_params.attention_mask_stride = params.attention_mask_stride;
|
||||
dispatch_params.attention_sinks = params.attention_sinks;
|
||||
dispatch_params.max_distance = max_distance;
|
||||
dispatch_params.cache_indir = params.cache_indir;
|
||||
dispatch_params.context_buf = mCpSize > 1 ? mhaOutput : params.context_buf; //
|
||||
@ -2267,6 +2324,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
|
||||
dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType;
|
||||
dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale;
|
||||
dispatch_params.rotary_embedding_inv_freq_cache = params.rotary_inv_freq;
|
||||
dispatch_params.rotary_embedding_cos_sin_cache = params.rotary_cos_sin;
|
||||
dispatch_params.rotary_embedding_short_m_scale = mRotaryEmbeddingShortMscale;
|
||||
dispatch_params.rotary_embedding_long_m_scale = mRotaryEmbeddingLongMscale;
|
||||
dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions;
|
||||
@ -2477,7 +2535,7 @@ int AttentionOp::initialize() noexcept
|
||||
}
|
||||
|
||||
// FP8 FMHA should be used with fp8 workflow together.
|
||||
if (mFP8ContextFMHA)
|
||||
if (mFP8ContextFMHA || mFP8ContextMLA)
|
||||
{
|
||||
data_type = DATA_TYPE_E4M3;
|
||||
}
|
||||
@ -2510,6 +2568,11 @@ int AttentionOp::initialize() noexcept
|
||||
fmhaParams.dataTypeOut = DATA_TYPE_BF16;
|
||||
fmhaParams.dataTypeKv = DATA_TYPE_BF16;
|
||||
}
|
||||
if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache())
|
||||
{
|
||||
fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
|
||||
fmhaParams.dataTypeOut = DATA_TYPE_BF16;
|
||||
}
|
||||
// TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to
|
||||
// bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime.
|
||||
fmhaParams.forceFp32Acc = false;
|
||||
@ -2563,7 +2626,7 @@ int AttentionOp::initialize() noexcept
|
||||
// Deepseek-V2 Generation needs a differ fmha with different argumments
|
||||
if (mIsMLAEnabled)
|
||||
{
|
||||
mEnableXQA = (mSM == kSM_120);
|
||||
mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA;
|
||||
if (mUseTllmGen)
|
||||
{
|
||||
Data_type qDataType = DATA_TYPE_FP32;
|
||||
@ -2826,6 +2889,7 @@ std::string AttentionOp::toString() const
|
||||
ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl;
|
||||
ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl;
|
||||
ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl;
|
||||
ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl;
|
||||
ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl;
|
||||
ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl;
|
||||
ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl;
|
||||
|
||||
@ -65,6 +65,8 @@ public:
|
||||
T const* qkv_bias = nullptr;
|
||||
// Attention mask input, which has shape of [batch_size, attention_mask_stride].
|
||||
bool const* attention_mask = nullptr;
|
||||
// Attention sinks with shape of [num_heads_q] float.
|
||||
float const* attention_sinks = nullptr;
|
||||
// Rotary inv_freq cache buffer to avoid re-computing.
|
||||
float const* rotary_inv_freq = nullptr;
|
||||
// Rotary cos sin cache buffer to avoid re-computing.
|
||||
@ -386,6 +388,7 @@ public:
|
||||
bool mPosShiftEnabled = false;
|
||||
bool mPagedContextFMHA = false;
|
||||
bool mFP8ContextFMHA = false;
|
||||
bool mFP8ContextMLA = false;
|
||||
bool mFP8GenerationMLA = false;
|
||||
bool mDenseContextFMHA = false;
|
||||
bool mHasFullAttentionMask = false;
|
||||
|
||||
@ -366,6 +366,12 @@ bool getEnvForceDeterministicMOE()
|
||||
return forceDeterministic;
|
||||
}
|
||||
|
||||
bool getEnvMOEDisableFinalizeFusion()
|
||||
{
|
||||
static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION");
|
||||
return moeDisableFinalizeFusion;
|
||||
}
|
||||
|
||||
bool getEnvForceDeterministicAttention()
|
||||
{
|
||||
static bool const forceDeterministic
|
||||
@ -386,7 +392,7 @@ size_t getEnvAllReduceWorkspaceSize()
|
||||
return workspaceSize;
|
||||
}
|
||||
|
||||
std::string getEnvKVCacheTransferOutputPath()
|
||||
std::string const& getEnvKVCacheTransferOutputPath()
|
||||
{
|
||||
static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or("");
|
||||
return outputPath;
|
||||
|
||||
@ -76,7 +76,7 @@ bool getEnvDisableKVCacheTransferOverlap();
|
||||
|
||||
bool getEnvEnableReceiveKVCacheParallel();
|
||||
|
||||
std::string getEnvKVCacheTransferOutputPath();
|
||||
std::string const& getEnvKVCacheTransferOutputPath();
|
||||
|
||||
bool getEnvTryZCopyForKVCacheTransfer();
|
||||
|
||||
@ -86,6 +86,9 @@ bool getEnvForceDeterministic();
|
||||
// Force deterministic behavior for MoE plugin.
|
||||
bool getEnvForceDeterministicMOE();
|
||||
|
||||
// Disable finalize fusion in MoE plugin
|
||||
bool getEnvMOEDisableFinalizeFusion();
|
||||
|
||||
// Force deterministic behavior for attention plugin.
|
||||
bool getEnvForceDeterministicAttention();
|
||||
|
||||
|
||||
@ -34,6 +34,8 @@ void fmtstr_(char const* format, fmtstr_allocator alloc, void* target, va_list a
|
||||
size_t constexpr init_size = 2048;
|
||||
char fixed_buffer[init_size];
|
||||
auto const size = std::vsnprintf(fixed_buffer, init_size, format, args0);
|
||||
va_end(args0);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(size >= 0, std::string(std::strerror(errno)));
|
||||
if (size == 0)
|
||||
{
|
||||
|
||||
@ -20,7 +20,8 @@
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::uintptr_t constexpr kCudaMemAlign = 128;
|
||||
// CuBLAS >= 12.9.1 requires 256-byte alignment.
|
||||
std::uintptr_t constexpr kCudaMemAlign = 256;
|
||||
|
||||
inline int8_t* alignPtr(int8_t* ptr, uintptr_t to)
|
||||
{
|
||||
|
||||
@ -27,6 +27,78 @@
|
||||
namespace cutlass::gemm::collective::detail
|
||||
{
|
||||
|
||||
using namespace cute;
|
||||
|
||||
typedef uint32_t __nv_fp4x8_storage_t;
|
||||
typedef uint32_t __nv_bf16x2_storage_t;
|
||||
typedef cutlass::uint128_t __nv_bf16x8_storage_t;
|
||||
|
||||
constexpr int int4_group_size = 128;
|
||||
constexpr int mxfp4_group_size = 32;
|
||||
|
||||
inline __device__ unsigned prmt(unsigned hi, unsigned lo, unsigned select_code)
|
||||
{
|
||||
unsigned res = 0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
"prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(res)
|
||||
: "r"(lo), "r"(hi), "r"(select_code));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_bf16(unsigned const index)
|
||||
{
|
||||
const __nv_fp8x4_storage_t h4b_lut = 0x03020100U; // 7654
|
||||
const __nv_fp8x4_storage_t l4b_lut = 0xFFFEFC00U; // 3210
|
||||
|
||||
__nv_fp8x4_storage_t lut_res = prmt(h4b_lut, l4b_lut, index);
|
||||
|
||||
return lut_res;
|
||||
}
|
||||
|
||||
__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(const __nv_fp4x8_storage_t fp4x8)
|
||||
{
|
||||
__nv_bf16x8_storage_t bf16x8_raw = {0, 0};
|
||||
__nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw);
|
||||
|
||||
unsigned zero_padding = 0x00000000U;
|
||||
|
||||
unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U;
|
||||
unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U);
|
||||
|
||||
__nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7654
|
||||
__nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3210
|
||||
|
||||
bf16x2_raw[0] = prmt(zero_padding, l4b_2to9_bits, 0x1707U) >> 2U; // 1 0
|
||||
bf16x2_raw[1] = prmt(zero_padding, l4b_2to9_bits, 0x3727U) >> 2U; // 3 2
|
||||
bf16x2_raw[2] = prmt(h4b_2to9_bits, zero_padding, 0x5040U) >> 2U; // 5 4
|
||||
bf16x2_raw[3] = prmt(h4b_2to9_bits, zero_padding, 0x7060U) >> 2U; // 7 6
|
||||
|
||||
__nv_bf16x2_storage_t bf16x2_0to1_bits;
|
||||
|
||||
__nv_fp8x4_storage_t h_fp8x2_0to1_bits = (fp4x8 & 0x0000C0C0U); // 3 1
|
||||
__nv_fp8x4_storage_t l_fp8x2_0to1_bits = (fp4x8 & 0x00000C0CU) << 4U; // 2 0
|
||||
|
||||
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x4707U); // 1 0
|
||||
bf16x2_raw[0] = bf16x2_raw[0] | bf16x2_0to1_bits;
|
||||
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x5717U); // 3 2
|
||||
bf16x2_raw[1] = bf16x2_raw[1] | bf16x2_0to1_bits;
|
||||
|
||||
h_fp8x2_0to1_bits = (fp4x8 & 0xC0C00000U); // 7 5
|
||||
l_fp8x2_0to1_bits = (fp4x8 & 0x0C0C0000U) << 4U; // 6 4
|
||||
|
||||
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x6020U); // 5 4
|
||||
bf16x2_raw[2] = bf16x2_raw[2] | bf16x2_0to1_bits;
|
||||
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x7030U); // 7 6
|
||||
bf16x2_raw[3] = bf16x2_raw[3] | bf16x2_0to1_bits;
|
||||
|
||||
return bf16x8_raw;
|
||||
}
|
||||
|
||||
template <class Collective>
|
||||
struct MixedGroupedGemmInputUtils
|
||||
{
|
||||
@ -46,6 +118,7 @@ private:
|
||||
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
|
||||
static constexpr auto ModeHasScales = Collective::ModeHasScales;
|
||||
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
|
||||
static constexpr auto UseFP4ToBF16LookupTable = Collective::UseFP4ToBF16LookupTable;
|
||||
|
||||
public:
|
||||
static constexpr auto elements_per_smem_scale()
|
||||
@ -239,6 +312,27 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// The core converter uses a lookup table to converts i4 -> 8 bit value.
|
||||
template <class EngineIn, class LayoutIn, class EngineOut,
|
||||
class LayoutOut>
|
||||
CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( // Accept mutable temporaries
|
||||
Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>&& dst)
|
||||
{
|
||||
fp4tobf16_lookup_table_convert(src, dst);
|
||||
}
|
||||
|
||||
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut>
|
||||
CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert(
|
||||
Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>& dst)
|
||||
{
|
||||
|
||||
// View the input as reg
|
||||
auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0);
|
||||
auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0);
|
||||
|
||||
dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_);
|
||||
}
|
||||
|
||||
/// Utilities to dequantize A.
|
||||
template <class Layout>
|
||||
CUTLASS_DEVICE static void static_check_scale(Layout const& tensor)
|
||||
@ -253,7 +347,6 @@ public:
|
||||
static_check_scale(flatten(Layout{}));
|
||||
}
|
||||
|
||||
// dequantize_A_kblock is here!!!
|
||||
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
|
||||
CUTLASS_DEVICE static void dequantize_A_kblock(Tensor<EngineIn, LayoutIn> const& tCrA_load,
|
||||
Tensor<EngineOut, LayoutOut>& tCrA_mma, cute::tuple<Ts...>& partitioned_extra_info, int const k_block)
|
||||
@ -288,8 +381,6 @@ public:
|
||||
}
|
||||
else if constexpr (UseScaleLookupTable)
|
||||
{
|
||||
// this path
|
||||
|
||||
constexpr int num_elements = decltype(size(src))::value;
|
||||
static_assert(is_same_v<RealSwappedElementA, cutlass::int4b_t>,
|
||||
"Lookup table only supports int4 being the quant type now.");
|
||||
@ -424,7 +515,6 @@ public:
|
||||
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
|
||||
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
|
||||
using SrcType = typename EngineIn::value_type;
|
||||
using DstType = typename EngineOut::value_type;
|
||||
|
||||
Tensor src = tCrA_load(_, _, k_block);
|
||||
Tensor dst = tCrA_mma(_, _, k_block);
|
||||
@ -441,7 +531,14 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i)
|
||||
{
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
if constexpr (UseFP4ToBF16LookupTable)
|
||||
{
|
||||
fp4tobf16_lookup_table_convert(src_vm(_, i), dst_vm(_, i));
|
||||
}
|
||||
else
|
||||
{
|
||||
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,568 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*! \file
|
||||
\brief Functor performing elementwise operations used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
#include "cute/numeric/numeric_types.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/arch/copy_red_global.hpp"
|
||||
#include "cutlass_extensions/util/gather_tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/builders/sm90_builder.inl"
|
||||
#include "cutlass/epilogue/collective/builders/sm90_common.inl"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias,
|
||||
class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R,
|
||||
class CopyOpR2G>
|
||||
class EpilogueMoeFusedFinalize
|
||||
{
|
||||
public:
|
||||
using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
|
||||
using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;
|
||||
|
||||
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
||||
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
|
||||
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
||||
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
||||
using ElementIntermediate = typename ThreadEpilogueOp::ElementD;
|
||||
|
||||
using ElementC = typename ThreadEpilogueOp::ElementC;
|
||||
using StrideC = StrideC_;
|
||||
using InternalStrideC = cute::remove_pointer_t<StrideC>;
|
||||
using ElementD = ElementD_;
|
||||
using StrideD = StrideD_;
|
||||
using InternalStrideD = cute::remove_pointer_t<StrideD>;
|
||||
|
||||
static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
|
||||
static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");
|
||||
|
||||
using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
|
||||
using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
|
||||
using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
|
||||
static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;
|
||||
|
||||
using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));
|
||||
|
||||
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
|
||||
};
|
||||
|
||||
struct TensorMapStorage
|
||||
{
|
||||
};
|
||||
|
||||
struct Arguments
|
||||
{
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
ElementC const** ptr_C{};
|
||||
StrideC dC{};
|
||||
ElementD* ptr_D{};
|
||||
StrideD dD{};
|
||||
ElementBias const* ptr_bias;
|
||||
StrideBias dBias{};
|
||||
ElementScale const* ptr_scale;
|
||||
StrideScale dScale{};
|
||||
int64_t const* group_offset{};
|
||||
int32_t const* scatter_index{};
|
||||
cutlass::FastDivmod num_rows_in_final_output;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params to_underlying_arguments(
|
||||
ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
|
||||
{
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
|
||||
void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
|
||||
{
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_HOST_DEVICE static bool can_implement(
|
||||
[[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
|
||||
{
|
||||
bool implementable = true;
|
||||
if (problem_shape.is_host_problem_shape_available())
|
||||
{
|
||||
// Check alignment for all problem sizes
|
||||
for (int i = 0; i < problem_shape.groups(); i++)
|
||||
{
|
||||
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
|
||||
}
|
||||
}
|
||||
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
|
||||
"reduction instruction.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
EpilogueMoeFusedFinalize(Params const& params_)
|
||||
: params(params_)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_source_needed()
|
||||
{
|
||||
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
|
||||
return params.ptr_C != nullptr
|
||||
&& (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
|
||||
}
|
||||
|
||||
template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
|
||||
class TiledMma, class ResidueMNK>
|
||||
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
|
||||
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
|
||||
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
|
||||
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
|
||||
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
|
||||
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
|
||||
|
||||
auto synchronize = [&]()
|
||||
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
auto M = get<0>(problem_shape_mnkl);
|
||||
auto N = get<1>(problem_shape_mnkl);
|
||||
auto L = get<3>(problem_shape_mnkl);
|
||||
|
||||
auto mma_tile_m = tile_size<0>(tiled_mma);
|
||||
auto mma_tile_n = tile_size<1>(tiled_mma);
|
||||
auto epi_tile_m = size<0>(EpilogueTile{});
|
||||
auto epi_tile_n = size<1>(EpilogueTile{});
|
||||
|
||||
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
|
||||
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
|
||||
|
||||
// Batches are managed by using appropriate pointers to C and D matrices
|
||||
int32_t const mock_L = 1;
|
||||
int32_t const mock_l_coord = 0;
|
||||
|
||||
// Slice to get the tile this CTA is responsible for
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
|
||||
|
||||
// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
|
||||
// we get the correct alpha/beta values for the current batch/group using group index.
|
||||
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
|
||||
|
||||
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
|
||||
Tensor sD = as_position_independent_swizzle_tensor(sD_);
|
||||
|
||||
// Function to scatter output rows
|
||||
auto& num_rows = params.num_rows_in_final_output;
|
||||
auto read_scatter_map = tensorrt_llm::cutlass_extensions::IndexedGather(
|
||||
make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
|
||||
auto get_scatter_idx = [&](auto i)
|
||||
{
|
||||
auto scatter = read_scatter_map(i);
|
||||
int quot, rem;
|
||||
num_rows(quot, rem, scatter);
|
||||
return rem;
|
||||
};
|
||||
|
||||
// Represent the full output tensor
|
||||
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
|
||||
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
|
||||
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
|
||||
Tensor mD_mnl = tensorrt_llm::cutlass_extensions::make_gather_tensor(
|
||||
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
|
||||
|
||||
// Use fake shape for bias, it doesn't matter
|
||||
bool const is_bias_needed = params.ptr_bias != nullptr;
|
||||
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
|
||||
Tensor mScale_mnl = make_tensor(
|
||||
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
|
||||
|
||||
Tensor gC_mnl
|
||||
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gD_mnl
|
||||
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
|
||||
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
|
||||
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
|
||||
|
||||
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor gBias_mnl
|
||||
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gScale_mnl
|
||||
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
|
||||
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
|
||||
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
|
||||
|
||||
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
// Get the smallest tiled copy we can use to retile the accumulators
|
||||
TiledCopy tiled_copy_C_atom
|
||||
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
|
||||
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
|
||||
|
||||
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
|
||||
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
|
||||
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
|
||||
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
|
||||
|
||||
// Make a tiled copy vectorized along major direction of D
|
||||
auto tiled_s2r = [&]()
|
||||
{
|
||||
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
|
||||
{
|
||||
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
|
||||
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomS2R{},
|
||||
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
|
||||
Layout<Shape<_1, Int<AlignmentD>>>{});
|
||||
}
|
||||
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
|
||||
{
|
||||
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
|
||||
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomS2R{},
|
||||
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
|
||||
Layout<Shape<Int<AlignmentD>, _1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
|
||||
}
|
||||
}();
|
||||
|
||||
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
|
||||
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
|
||||
// Allocate intermediate registers for a single subtile
|
||||
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
|
||||
// Make an identity coordinate tensor for predicating our output MN tile
|
||||
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
|
||||
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
|
||||
// epilogue subtile loop
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
|
||||
{
|
||||
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
|
||||
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
|
||||
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
|
||||
|
||||
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
|
||||
int r2s_v = epi_n_in_mma * size(tRS_rD);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
|
||||
{
|
||||
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
|
||||
}
|
||||
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD);
|
||||
synchronize();
|
||||
|
||||
copy(tiled_s2r, tSR_sD, tSR_rD);
|
||||
synchronize();
|
||||
|
||||
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
|
||||
|
||||
if (epilogue_op.is_source_needed())
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<1>(tSR_rD); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<2>(tSR_rD); ++n)
|
||||
{
|
||||
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
|
||||
{
|
||||
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
|
||||
}
|
||||
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tSR_rD); ++i)
|
||||
{
|
||||
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
|
||||
}
|
||||
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
|
||||
}
|
||||
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<1>(tSR_rD); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<2>(tSR_rD); ++n)
|
||||
{
|
||||
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
|
||||
{
|
||||
if (is_bias_needed)
|
||||
{
|
||||
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
|
||||
}
|
||||
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tSR_rD); ++i)
|
||||
{
|
||||
auto epi_value = epilogue_op(tSR_rD(i, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
|
||||
}
|
||||
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
|
||||
}
|
||||
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Params params;
|
||||
};
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template <class Element, class MaxVec>
|
||||
constexpr auto get_vectorized_atomic_add_op()
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
auto constexpr MaxVecSize = size(MaxVec{});
|
||||
|
||||
if constexpr (is_same_v<Element, cutlass::half_t>)
|
||||
{
|
||||
if constexpr (MaxVecSize >= 8)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 4)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 2)
|
||||
{
|
||||
return SM70_RED_ADD_NOFTZ_F16x2{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SM70_RED_ADD_NOFTZ_F16{};
|
||||
}
|
||||
}
|
||||
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
|
||||
{
|
||||
if constexpr (MaxVecSize >= 8)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 4)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 2)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// non-vectorized atomic add for all other types until supported
|
||||
return TypedAtomicAdd<Element>{};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class Arch, class TileShape, class ElementC, class StrideC, class ElementD, class StrideD,
|
||||
class ElementAccumulator, class ElementCompute, class ElementBias, class StrideBias, class ElementScale,
|
||||
class StrideScale>
|
||||
struct EpilogueMoeFusedFinalizeBuilder
|
||||
{
|
||||
|
||||
// assuming cooperative kernel schedule
|
||||
using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
|
||||
using EpilogueTile = Shape<_128, EpiTileN>;
|
||||
|
||||
// Output of linear combination is ElementCompute instead of ElementD
|
||||
// since we will be doing more computate on it, no need to cast yet.
|
||||
using ThreadEpilogueOp
|
||||
= cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
|
||||
|
||||
using SmemLayoutAtomD
|
||||
= decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
|
||||
using CopyAtomR2S
|
||||
= decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator, EpilogueTile>());
|
||||
using CopyAtomS2R = DefaultCopy;
|
||||
using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());
|
||||
|
||||
template <class Base, class EpilogueOp>
|
||||
struct TmaWarpSpecializedAdapterWithSmemStorageImpl : Base
|
||||
{
|
||||
// We need to override this one using declaration because otherwise we double up on the smem
|
||||
using TensorMapStorage = typename EpilogueOp::TensorMapStorage;
|
||||
|
||||
// using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
TmaWarpSpecializedAdapterWithSmemStorageImpl(
|
||||
typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
|
||||
: Base(params)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto load_init([[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count,
|
||||
[[maybe_unused]] int32_t sm_idx)
|
||||
{
|
||||
return cute::make_tuple(nullptr);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto store_init([[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count,
|
||||
[[maybe_unused]] int32_t sm_idx, [[maybe_unused]] int32_t warp_group_idx)
|
||||
{
|
||||
return cute::make_tuple(nullptr);
|
||||
}
|
||||
|
||||
// Dummy methods to perform different parts of TMA/Tensormap modifications
|
||||
|
||||
template <bool IsLoad, class ProblemShapeMNKL>
|
||||
CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormaps,
|
||||
[[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShapeMNKL problem_shape,
|
||||
[[maybe_unused]] int32_t next_batch, [[maybe_unused]] int32_t warp_group_idx)
|
||||
{
|
||||
}
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormaps,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t warp_group_idx)
|
||||
{
|
||||
}
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE void tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <class EpilogueOp>
|
||||
using TmaWarpSpecializedAdapterWithSmemStorage = TmaWarpSpecializedAdapterWithSmemStorageImpl<
|
||||
std::conditional_t<Arch::kMinComputeCapability >= 100, detail::Sm100TmaWarpSpecializedAdapter<EpilogueOp>,
|
||||
detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>>,
|
||||
EpilogueOp>;
|
||||
|
||||
using CollectiveOp = TmaWarpSpecializedAdapterWithSmemStorage<
|
||||
EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
|
||||
StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace collective
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,547 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/epilogue/fusion/operations.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
#include "cutlass_extensions/arch/copy_red_global.hpp"
|
||||
#include "cutlass_extensions/util/gather_tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
template <
|
||||
class EpilogueTile,
|
||||
class StrideOutput,
|
||||
class SmemLayoutAtom,
|
||||
class CopyOpR2S,
|
||||
class ElementOutput,
|
||||
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct Sm90ScatterPtrArray {
|
||||
|
||||
using SmemShape = decltype(make_shape(size(make_layout(get<0>(EpilogueTile{}))), size(make_layout(get<1>(EpilogueTile{})))));
|
||||
using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{}));
|
||||
|
||||
using ElementIndex = int32_t;
|
||||
// TODO: more generic treatment, or pass StrideIndex via template param?
|
||||
using StrideIndex = conditional_t<cutlass::gemm::detail::is_mn_major<StrideOutput>(), Stride<_0,_1,_0>, Stride<_1,_0,_0>>;
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
struct Arguments {
|
||||
ElementOutput* ptr_out = nullptr;
|
||||
StrideOutput dOut = {};
|
||||
ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
|
||||
int index_modulo{}; // modulo used to transform the index before store
|
||||
bool use_reduction = true;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ElementOutput* ptr_out = nullptr;
|
||||
StrideOutput dOut = {};
|
||||
ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
|
||||
cutlass::FastDivmod index_divmod{}; // modulo used to transform the index before store
|
||||
bool use_reduction = true;
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return {
|
||||
args.ptr_out,
|
||||
args.dOut,
|
||||
args.ptr_index,
|
||||
cutlass::FastDivmod(args.index_modulo),
|
||||
args.use_reduction
|
||||
};
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ScatterPtrArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ScatterPtrArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<
|
||||
class ArgsTuple
|
||||
>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(ArgsTuple&& args_tuple)
|
||||
: args_tuple(std::move(args_tuple)) {}
|
||||
|
||||
ArgsTuple args_tuple;
|
||||
|
||||
template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
|
||||
CUTLASS_DEVICE auto
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n,
|
||||
Array<ElementInput, FragmentSize> const& frg_input) {
|
||||
|
||||
auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple;
|
||||
|
||||
using ConvertInput = NumericArrayConverter<ElementOutput, ElementInput, FragmentSize, RoundStyle>;
|
||||
ConvertInput convert_input{};
|
||||
|
||||
Tensor tC_rOut_frg = recast<Array<ElementOutput, FragmentSize>>(coalesce(tC_rOut)); // (EPI_V)
|
||||
tC_rOut_frg(epi_v) = convert_input(frg_input);
|
||||
|
||||
return tC_rOut_frg(epi_v);
|
||||
}
|
||||
|
||||
template <class STensor, class SyncFn, class VTensor>
|
||||
CUTLASS_DEVICE void
|
||||
reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) {
|
||||
|
||||
auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple;
|
||||
|
||||
Tensor byte_buffer = recast<uint8_t>(reduction_buffer);
|
||||
static_assert(cosize(byte_buffer.layout()) * sizeof_bits_v<uint8_t> >= cosize(SmemLayout{}) * sizeof_bits_v<ElementOutput>,
|
||||
"Not enough space in scratch smem buffer");
|
||||
|
||||
Tensor sOut = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(recast_ptr<ElementOutput>(byte_buffer.data())), SmemLayout{}));
|
||||
|
||||
auto thread_r2s = tiled_r2s.get_slice(thread_idx);
|
||||
Tensor tRS_sOut_epi = thread_r2s.partition_D(sOut);
|
||||
Tensor tRS_rOut_epi = thread_r2s.retile_S(tC_rOut);
|
||||
|
||||
auto thread_r2g = tiled_r2g_red.get_slice(thread_idx);
|
||||
Tensor tRG_gOut_epi = tRG_gOut(_,_,_,epi_m,epi_n);
|
||||
Tensor tRG_sOut_epi = thread_r2g.partition_D(sOut);
|
||||
Tensor tRG_rOut_epi = thread_r2g.retile_S(make_tensor(tC_rOut.data(), shape(tRG_sOut_epi))); // reuse D registers
|
||||
|
||||
// sanity check for register reuse
|
||||
CUTE_STATIC_ASSERT_V(cosize(tC_rOut.layout()) == cosize(tRG_rOut_epi.layout()), "Invalid register count for R2G");
|
||||
|
||||
copy(tiled_r2s, tRS_rOut_epi, tRS_sOut_epi);
|
||||
sync_fn();
|
||||
copy(tRG_sOut_epi, tRG_rOut_epi);
|
||||
|
||||
auto residue = residue_cD; // capturing structured bindings is a C++20 feature
|
||||
Tensor tRG_cD_epi = tRG_cD(0,_,_,epi_m,epi_n);
|
||||
auto pred = cute::lazy::transform(tRG_cD_epi, [&](auto c){ return elem_less(c, residue); });
|
||||
|
||||
if (use_reduction) {
|
||||
copy_if(tiled_r2g_red, pred, tRG_rOut_epi, tRG_gOut_epi);
|
||||
}
|
||||
else {
|
||||
copy_if(tiled_r2g_stg, pred, tRG_rOut_epi, tRG_gOut_epi);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class Element, int MaxVecSize>
|
||||
static constexpr auto get_reduction_op()
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
// For now only support red.add
|
||||
if constexpr (is_same_v<Element, cutlass::half_t>) {
|
||||
if constexpr (MaxVecSize % 8 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize % 4 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize % 2 == 0) {
|
||||
return SM70_RED_ADD_NOFTZ_F16x2{};
|
||||
}
|
||||
else {
|
||||
return SM70_RED_ADD_NOFTZ_F16{};
|
||||
}
|
||||
}
|
||||
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>) {
|
||||
if constexpr (MaxVecSize % 8 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize % 4 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize % 2 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2{};
|
||||
}
|
||||
else {
|
||||
return SM90_RED_ADD_NOFTZ_BF16{};
|
||||
}
|
||||
}
|
||||
else {
|
||||
// non-vectorized atomic add for all other types until supported
|
||||
return TypedAtomicAdd<Element>{};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
|
||||
auto index_read = [index = params_ptr->ptr_index[l], divmod = params_ptr->index_divmod](auto i){ return divmod.rem(index[i]); };
|
||||
Tensor mOut = cutlass::util::make_gather_tensor(params_ptr->ptr_out, make_shape(M,N,Int<1>{}), params_ptr->dOut, index_read); // (M,N,_1)
|
||||
Tensor gOut = local_tile(mOut, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N)
|
||||
Tensor gOut_epi = flat_divide(gOut, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor mIdx = make_tensor(params_ptr->ptr_index[l], make_shape(M,N,Int<1>{}), StrideIndex{}); // (M,N,_1)
|
||||
Tensor gIdx = local_tile(mIdx, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N)
|
||||
Tensor gIdx_epi = flat_divide(gIdx, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor cD_epi = flat_divide(args.cD, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor tC_gOut = sm90_partition_for_epilogue<ReferenceSrc>(gOut, args.epi_tile, args.tiled_copy, args.thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Tensor tC_rOut = make_tensor<ElementOutput>(take<0,3>(shape(tC_gOut))); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
auto tiled_r2s = conditional_return<ReferenceSrc>(
|
||||
make_tiled_copy_S(Copy_Atom<CopyOpR2S,ElementOutput>{}, args.tiled_copy),
|
||||
make_tiled_copy_D(Copy_Atom<CopyOpR2S,ElementOutput>{}, args.tiled_copy)
|
||||
);
|
||||
|
||||
// Vectorization must not exceed alignment and also the number of values per thread in the tile
|
||||
int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy));
|
||||
int constexpr NumValTile = product(take<0,2>(shape(cD_epi)));
|
||||
int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads);
|
||||
|
||||
// Choose the largest available red.global op and an st.global op with matching vectorization
|
||||
using CopyOpR2GRed = decltype(get_reduction_op<ElementOutput, MaxVecSize>());
|
||||
using CopyOpR2GStg = UniversalCopy<uint_bit_t<Copy_Atom<CopyOpR2GRed,ElementOutput>::NumValSrc * sizeof_bits_v<ElementOutput>>>;
|
||||
|
||||
auto make_tiled_r2g = [&](auto copy_op)
|
||||
{
|
||||
using CopyAtomR2G = Copy_Atom<decltype(copy_op),ElementOutput>;
|
||||
constexpr int VecSize = CopyAtomR2G::NumValSrc;
|
||||
if constexpr (cutlass::gemm::detail::is_k_major<StrideOutput>()) {
|
||||
constexpr int ThreadsMajor = size<1>(args.epi_tile) / VecSize;
|
||||
constexpr int ThreadsMinor = NumThreads / ThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomR2G{},
|
||||
Layout<Shape<Int<ThreadsMinor>, Int<ThreadsMajor>>, Stride<Int<ThreadsMajor>, _1>>{},
|
||||
Layout<Shape<_1, Int<VecSize>>>{});
|
||||
}
|
||||
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideOutput>()) {
|
||||
constexpr int ThreadsMajor = size<0>(args.epi_tile) / VecSize;
|
||||
constexpr int ThreadsMinor = NumThreads / ThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomR2G{},
|
||||
Layout<Shape<Int<ThreadsMajor>, Int<ThreadsMinor>>, Stride<_1, Int<ThreadsMajor>>>{},
|
||||
Layout<Shape<Int<VecSize>, _1>>{});
|
||||
}
|
||||
else {
|
||||
static_assert(cute::is_void_v<StrideOutput>, "Unsupported D gmem layout.");
|
||||
}
|
||||
};
|
||||
|
||||
auto tiled_r2g_red = make_tiled_r2g(CopyOpR2GRed{});
|
||||
auto tiled_r2g_stg = make_tiled_r2g(CopyOpR2GStg{});
|
||||
|
||||
// Sanity checks - since we will be using one tiled copy with tensors partitioned with the other tiled copy,
|
||||
// ensure they have matching layouts/tilers
|
||||
using TiledR2GRed = decltype(tiled_r2g_red);
|
||||
using TiledR2GStg = decltype(tiled_r2g_stg);
|
||||
static_assert(typename TiledR2GRed::AtomLayoutSrc{} == typename TiledR2GStg::AtomLayoutSrc{}, "Mismatching AtomLayoutSrc");
|
||||
static_assert(typename TiledR2GRed::AtomLayoutDst{} == typename TiledR2GStg::AtomLayoutDst{}, "Mismatching AtomLayoutDst");
|
||||
static_assert(typename TiledR2GRed::TiledLayout_TV{} == typename TiledR2GStg::TiledLayout_TV{}, "Mismatching TiledLayout_TV");
|
||||
static_assert(typename TiledR2GRed::Tiler_MN{} == typename TiledR2GStg::Tiler_MN{}, "Mismatching Tiler_MN");
|
||||
|
||||
auto thread_r2g = tiled_r2g_red.get_slice(args.thread_idx);
|
||||
Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)
|
||||
Tensor tRG_cD = thread_r2g.partition_D(cD_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)
|
||||
|
||||
auto args_tuple = make_tuple(
|
||||
cute::move(tC_rOut),
|
||||
tiled_r2s,
|
||||
tRG_gOut,
|
||||
tRG_cD,
|
||||
tiled_r2g_red,
|
||||
tiled_r2g_stg,
|
||||
params_ptr->use_reduction,
|
||||
args.thread_idx,
|
||||
args.residue_cD);
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple));
|
||||
}
|
||||
};
|
||||
|
||||
template<
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentBias_ = 128 / cute::sizeof_bits_v<ElementBias_>,
|
||||
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct ScaledAccPerRowBias
|
||||
: ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_>
|
||||
{
|
||||
using ElementBias = ElementBias_;
|
||||
static constexpr int AlignmentBias = AlignmentBias_;
|
||||
static constexpr bool IsPerRowBiasSupported = true;
|
||||
};
|
||||
|
||||
template<
|
||||
class GmemLayoutTagOut,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementScale = ElementCompute,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
|
||||
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct ScaledAccPerRowBiasPerColScaleScatter
|
||||
: ScaledAccPerRowBias<ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
|
||||
{
|
||||
using ElementAux = ElementOutput;
|
||||
using GmemLayoutTagAux = GmemLayoutTagOut;
|
||||
static constexpr int AlignmentAux = AlignmentOutput;
|
||||
static constexpr bool IsAuxOutSupported = true;
|
||||
};
|
||||
|
||||
// D = alpha * acc + per-row bias
|
||||
template<
|
||||
class CtaTileShapeMNK,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90ScaledAccPerRowBiasPtrArray =
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int64_t>>, // alpha
|
||||
Sm90AccFetch, // acc
|
||||
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias
|
||||
>;
|
||||
|
||||
template<
|
||||
class CtaTileShapeMNK,
|
||||
class EpilogueTile,
|
||||
class StrideOutput,
|
||||
class SmemLayoutAtom,
|
||||
class CopyOpR2S,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementScale = ElementCompute,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
|
||||
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray =
|
||||
Sm90EVT<Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>, // scatter store
|
||||
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // scale * (alpha * acc + bias)
|
||||
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_0,_1,int64_t>, 1>, // scale
|
||||
Sm90ScaledAccPerRowBiasPtrArray<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> // alpha * acc + bias
|
||||
>
|
||||
>;
|
||||
|
||||
template <
|
||||
int StagesC,
|
||||
int StagesD,
|
||||
int FragmentSize,
|
||||
bool ReuseSmemC,
|
||||
bool DelayTmaStore,
|
||||
int NumEpilogueWarpGroups,
|
||||
class GmemLayoutTagOut,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias,
|
||||
class ElementScale,
|
||||
class ElementScalar,
|
||||
int AlignmentBias,
|
||||
int AlignmentOutput,
|
||||
FloatRoundStyle RoundStyle,
|
||||
class CtaTileShapeMNK,
|
||||
class EpilogueTile,
|
||||
class SmemLayoutAtom,
|
||||
class CopyOpR2S
|
||||
>
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90PtrArrayTmaWarpSpecialized<StagesC,
|
||||
StagesD,
|
||||
FragmentSize,
|
||||
ReuseSmemC,
|
||||
DelayTmaStore,
|
||||
NumEpilogueWarpGroups
|
||||
>,
|
||||
fusion::ScaledAccPerRowBiasPerColScaleScatter<GmemLayoutTagOut,
|
||||
ElementOutput,
|
||||
ElementCompute,
|
||||
ElementBias,
|
||||
ElementScale,
|
||||
ElementScalar,
|
||||
AlignmentBias,
|
||||
AlignmentOutput,
|
||||
RoundStyle>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
SmemLayoutAtom,
|
||||
CopyOpR2S
|
||||
> : Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray<
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>,
|
||||
SmemLayoutAtom, CopyOpR2S,
|
||||
ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
|
||||
AlignmentBias, AlignmentOutput, RoundStyle
|
||||
> {
|
||||
|
||||
using StrideOutput = cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>;
|
||||
|
||||
using Impl = Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray<
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
StrideOutput,
|
||||
SmemLayoutAtom, CopyOpR2S,
|
||||
ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
|
||||
AlignmentBias, AlignmentOutput, RoundStyle
|
||||
>;
|
||||
using Operation = fusion::ScaledAccPerRowBiasPerColScaleScatter<
|
||||
GmemLayoutTagOut,
|
||||
ElementOutput,
|
||||
ElementCompute,
|
||||
ElementBias,
|
||||
ElementScale,
|
||||
ElementScalar,
|
||||
AlignmentBias,
|
||||
AlignmentOutput,
|
||||
RoundStyle>;
|
||||
|
||||
struct Arguments {
|
||||
|
||||
using StrideAlpha = Stride<_0,_0,int64_t>;
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar const* alpha_ptr{};
|
||||
ElementScalar const* const* alpha_ptr_array{};
|
||||
StrideAlpha dAlpha{};
|
||||
|
||||
using StrideBias = Stride<_1,_0,int64_t>;
|
||||
ElementBias const* const* bias_ptr{};
|
||||
StrideBias dBias{};
|
||||
|
||||
using StrideScale = Stride<_0,_1,int64_t>;
|
||||
ElementScalar const* const* scale_ptr_array{};
|
||||
StrideScale dScale{};
|
||||
|
||||
// Nested args not usable due to a compiler bug with constexpr evaluation
|
||||
// using ScatterArguments = typename Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>::Arguments;
|
||||
// ScatterArguments scatter{};
|
||||
|
||||
ElementOutput* ptr_out = nullptr;
|
||||
StrideOutput dOut = {};
|
||||
int const* const* ptr_index{}; // per-group pointer to the scatter index
|
||||
int index_modulo{}; // modulo used to transform the index before store
|
||||
bool use_reduction = true;
|
||||
|
||||
operator typename Impl::Arguments() const {
|
||||
return
|
||||
{ // unary op: reduce(scale * (beta * C + (alpha * acc)))
|
||||
{ // binary op: scale * (beta * C + (alpha * acc))
|
||||
{ scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast
|
||||
{ // ternary op : alpha * acc + bias
|
||||
{{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha
|
||||
{}, // leaf args : acc
|
||||
{bias_ptr, ElementBias(0), dBias}, // leaf args : bias
|
||||
{} // ternary args : multiply_add
|
||||
}, // end binary op
|
||||
{} // binary args: multiply
|
||||
}, // end binary op
|
||||
//scatter // unary args: reduce
|
||||
{ ptr_out, dOut, ptr_index, index_modulo, use_reduction }
|
||||
}; // end unary op
|
||||
}
|
||||
};
|
||||
|
||||
// Ctor inheritance
|
||||
using Impl::Impl;
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::epilogue::fusion
|
||||
|
||||
// clang-format on
|
||||
@ -30,37 +30,12 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#define GROUP_SIZE 128
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
template <int N>
|
||||
CUTE_HOST_DEVICE void warpgroup_wait_()
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
||||
cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N);
|
||||
asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group<N> without CUTE_ARCH_MMA_SM90A_ENABLED");
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void warpgroup_wait_dispatch(int onthefly_count)
|
||||
{
|
||||
switch (onthefly_count)
|
||||
{
|
||||
case 0: warpgroup_wait_<0>(); break;
|
||||
case 4: warpgroup_wait_<4>(); break;
|
||||
case 8: warpgroup_wait_<8>(); break;
|
||||
case 12: warpgroup_wait_<12>(); break;
|
||||
default: assert(false && "Invalid onthefly_count value");
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
@ -91,7 +66,7 @@ public:
|
||||
private:
|
||||
template <class T>
|
||||
friend struct detail::MixedGroupedGemmInputUtils;
|
||||
using CollectiveType = CollectiveMma<DispatchPolicy, TileShape_, ElementAOptionalTuple, StrideA_,
|
||||
using CollectiveType = CollectiveMmaArrayMixedInput<DispatchPolicy, TileShape_, ElementAOptionalTuple, StrideA_,
|
||||
ElementBOptionalTuple, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_>;
|
||||
using Utils = detail::MixedGroupedGemmInputUtils<CollectiveType>;
|
||||
@ -146,6 +121,11 @@ public:
|
||||
static_assert(cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(),
|
||||
"Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled].");
|
||||
|
||||
static constexpr bool IsMXFP4 = cute::is_same_v<ElementA, cutlass::float_e2m1_t>;
|
||||
// Group size 128 for int4 weights
|
||||
// Group size 32 for mxfp4 weights
|
||||
static constexpr int ScalingGroupSize = IsMXFP4 ? detail::mxfp4_group_size : detail::int4_group_size;
|
||||
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
@ -268,6 +248,8 @@ public:
|
||||
|| KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;
|
||||
static constexpr bool UseScaleLookupTable
|
||||
= KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v<ElementScale>;
|
||||
static constexpr bool UseFP4ToBF16LookupTable = KernelConversionMode == ConversionMode::ConvertAndScale
|
||||
&& cute::is_same_v<ElementA, cutlass::float_e2m1_t> && cute::is_same_v<ElementB, cutlass::bfloat16_t>;
|
||||
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{});
|
||||
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{});
|
||||
static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB);
|
||||
@ -705,7 +687,7 @@ public:
|
||||
{
|
||||
// The real scale_k that actually works
|
||||
// auto scale_k = K / mainloop_params.chunk_size;
|
||||
auto scale_k = K / GROUP_SIZE;
|
||||
auto scale_k = K / ScalingGroupSize;
|
||||
|
||||
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l)
|
||||
Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
|
||||
@ -872,7 +854,6 @@ public:
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
|
||||
{
|
||||
// zero copy
|
||||
auto tZgZ = get<2>(extra_input_partitions);
|
||||
auto tZsZ = get<3>(extra_input_partitions);
|
||||
if (cute::elect_one_sync())
|
||||
@ -979,7 +960,8 @@ public:
|
||||
return make_tensor_like<RealSwappedElementA>(tCsA(_, _, _, Int<0>{}));
|
||||
}
|
||||
}();
|
||||
Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
// tCrB is just a view of the tensor tCsB
|
||||
Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
//
|
||||
@ -1013,8 +995,8 @@ public:
|
||||
|
||||
multiply_add<ElementAccumulator> fma;
|
||||
|
||||
constexpr int NumMMAsPerChunk = GROUP_SIZE / cute::get<0, 1>(tCsB.shape())();
|
||||
constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / GROUP_SIZE;
|
||||
constexpr int NumMMAsPerChunk = ScalingGroupSize / cute::get<0, 1>(tCsB.shape())();
|
||||
constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / ScalingGroupSize;
|
||||
cute::array<decltype(make_fragment_like(accum)), NumChunksPerTileK> intermediate_array;
|
||||
|
||||
constexpr int K_BLOCK_MAX = size<2>(tCrA_load);
|
||||
@ -1045,8 +1027,6 @@ public:
|
||||
// src: tCrA_load, dst: tCrA_mma
|
||||
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id)
|
||||
@ -1079,10 +1059,11 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_wait<0>();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_)
|
||||
{
|
||||
warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk);
|
||||
warpgroup_fence_operand(intermediate_array[chunk_id_]);
|
||||
|
||||
// Apply the group-wise scaling
|
||||
@ -1129,7 +1110,6 @@ public:
|
||||
Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info,
|
||||
copy_partitions_extra_info, 1, smem_pipe_read.index());
|
||||
|
||||
warpgroup_wait<K_WAIT_MAX>();
|
||||
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
|
||||
}
|
||||
}
|
||||
@ -1169,8 +1149,6 @@ public:
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
warpgroup_commit_batch();
|
||||
|
||||
warpgroup_wait<K_WAIT_MAX>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage,
|
||||
// so we can release prior barrier
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
pipeline.consumer_release(
|
||||
@ -1187,10 +1165,11 @@ public:
|
||||
{
|
||||
// The last k_block
|
||||
|
||||
warpgroup_wait<0>();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_)
|
||||
{
|
||||
warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk);
|
||||
warpgroup_fence_operand(intermediate_array[chunk_id_]);
|
||||
|
||||
// Apply the group-wise scaling
|
||||
@ -1257,7 +1236,6 @@ public:
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
warpgroup_commit_batch();
|
||||
|
||||
warpgroup_wait<K_WAIT_MAX>();
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
// release prior barrier
|
||||
@ -1318,7 +1296,7 @@ public:
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
// warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count)
|
||||
{
|
||||
@ -1462,7 +1440,7 @@ public:
|
||||
{
|
||||
NonVoidElementScale const* ptr_S = nullptr;
|
||||
// auto scale_k = K / mainloop_params.chunk_size;
|
||||
auto scale_k = K / GROUP_SIZE;
|
||||
auto scale_k = K / ScalingGroupSize;
|
||||
Tensor tensor_scale = make_tensor(
|
||||
detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]);
|
||||
cute::detail::fill_tma_gmem_shape_stride(
|
||||
@ -1472,7 +1450,7 @@ public:
|
||||
{
|
||||
ElementZero const* ptr_Z = nullptr;
|
||||
// auto scale_k = K / mainloop_params.chunk_size;
|
||||
auto scale_k = K / GROUP_SIZE;
|
||||
auto scale_k = K / ScalingGroupSize;
|
||||
Tensor tensor_zero = make_tensor(
|
||||
detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]);
|
||||
cute::detail::fill_tma_gmem_shape_stride(
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/util/print.hpp"
|
||||
|
||||
namespace tensorrt_llm::cutlass_extensions
|
||||
namespace cutlass::util
|
||||
{
|
||||
|
||||
/// Function object that applies an index to its argument
|
||||
@ -81,7 +81,7 @@ struct CustomStride
|
||||
template <class Div>
|
||||
CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
|
||||
{
|
||||
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
|
||||
return CustomStride<Func, decltype(cute::safe_div(s.stride_, div))>(s.func_, cute::safe_div(s.stride_, div));
|
||||
}
|
||||
|
||||
// Circumvent the requirement on make_layout that shape and stride are integral
|
||||
@ -116,7 +116,7 @@ CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, S
|
||||
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
|
||||
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
|
||||
}
|
||||
} // namespace tensorrt_llm::cutlass_extensions
|
||||
} // namespace cutlass::util
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
set(DEEP_EP_COMMIT edf3ea2b086a393d3163bf2773eab69d9191cc01)
|
||||
set(DEEP_EP_COMMIT 515a311f290eb6d9592fcccfcc80c40f5123ca72)
|
||||
set(NVSHMEM_URL_HASH
|
||||
SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
|
||||
|
||||
@ -19,8 +19,15 @@ foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
|
||||
set(CUDA_ARCH_MINOR ${CMAKE_MATCH_2})
|
||||
set(CUDA_ARCH_POSTFIX ${CMAKE_MATCH_3})
|
||||
if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 9)
|
||||
list(APPEND DEEP_EP_CUDA_ARCHITECTURES
|
||||
"${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}")
|
||||
# The FP4-related conversion instructions in DeepEP require SM100a, SM110a,
|
||||
# or SM120a.
|
||||
if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 10 AND ${CUDA_ARCH_MINOR} EQUAL 0)
|
||||
list(APPEND DEEP_EP_CUDA_ARCHITECTURES
|
||||
"${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}a${CUDA_ARCH_POSTFIX}")
|
||||
else()
|
||||
list(APPEND DEEP_EP_CUDA_ARCHITECTURES
|
||||
"${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
|
||||
@ -132,10 +132,12 @@ std::optional<std::shared_ptr<KVCacheEventManager>> Executor::getKVCacheEventMan
|
||||
return mImpl->getKVCacheEventManager();
|
||||
}
|
||||
|
||||
KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data, SizeType32 windowSize)
|
||||
KVCacheEvent::KVCacheEvent(
|
||||
size_t eventId, KVCacheEventData data, SizeType32 windowSize, std::optional<SizeType32> attentionDpRank)
|
||||
: eventId{eventId}
|
||||
, data{std::move(data)}
|
||||
, windowSize{windowSize}
|
||||
, attentionDpRank{attentionDpRank}
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -27,6 +27,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
|
||||
std::optional<size_t> const& hostCacheSize, bool onboardBlocks,
|
||||
std::optional<FloatType> const& crossKvCacheFraction, std::optional<RetentionPriority> secondaryOffloadMinPriority,
|
||||
size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm,
|
||||
SizeType32 attentionDpEventsGatherPeriodMs,
|
||||
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults)
|
||||
: mEnableBlockReuse(enableBlockReuse)
|
||||
, mHostCacheSize(hostCacheSize)
|
||||
@ -36,6 +37,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
|
||||
, mEnablePartialReuse{enablePartialReuse}
|
||||
, mCopyOnPartialReuse{copyOnPartialReuse}
|
||||
, mUseUvm{useUvm}
|
||||
, mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
|
||||
{
|
||||
if (maxTokens)
|
||||
{
|
||||
@ -61,6 +63,8 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
|
||||
{
|
||||
fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value());
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mAttentionDpEventsGatherPeriodMs > 0, "Attention DP events gather period must be greater than 0");
|
||||
}
|
||||
|
||||
bool KvCacheConfig::getEnableBlockReuse() const
|
||||
@ -128,6 +132,11 @@ bool KvCacheConfig::getUseUvm() const
|
||||
return mUseUvm;
|
||||
}
|
||||
|
||||
SizeType32 KvCacheConfig::getAttentionDpEventsGatherPeriodMs() const
|
||||
{
|
||||
return mAttentionDpEventsGatherPeriodMs;
|
||||
}
|
||||
|
||||
void KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse)
|
||||
{
|
||||
mEnableBlockReuse = enableBlockReuse;
|
||||
@ -204,6 +213,12 @@ void KvCacheConfig::setUseUvm(bool useUvm)
|
||||
mUseUvm = useUvm;
|
||||
}
|
||||
|
||||
void KvCacheConfig::setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs)
|
||||
{
|
||||
TLLM_CHECK(attentionDpEventsGatherPeriodMs > 0);
|
||||
mAttentionDpEventsGatherPeriodMs = attentionDpEventsGatherPeriodMs;
|
||||
}
|
||||
|
||||
void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults)
|
||||
{
|
||||
if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec)
|
||||
|
||||
@ -27,26 +27,29 @@ LoraConfig::LoraConfig(IdType taskId, std::optional<Tensor> weights, std::option
|
||||
, mWeights(std::move(weights))
|
||||
, mConfig(std::move(config))
|
||||
{
|
||||
if (mWeights.has_value() || mConfig.has_value())
|
||||
if (mConfig.has_value())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mWeights.has_value() && mConfig.has_value(),
|
||||
"Request for LoRA inference must have both lora weights and lora config");
|
||||
|
||||
SizeType32 constexpr expectedWeightsDims = 2;
|
||||
SizeType32 constexpr expectedConfigDims = 2;
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions");
|
||||
TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU
|
||||
&& mConfig.value().getMemoryType() != MemoryType::kUNKNOWN,
|
||||
"Expected lora config to be in CPU memory");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32");
|
||||
}
|
||||
if (mWeights.has_value())
|
||||
{
|
||||
SizeType32 constexpr expectedWeightsDims = 2;
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mConfig.has_value(), "Request for LoRA inference with lora weights must also have lora config");
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions");
|
||||
|
||||
TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU
|
||||
&& mWeights.value().getMemoryType() != MemoryType::kUNKNOWN,
|
||||
"Expected lora weights to be in CPU memory");
|
||||
TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU
|
||||
&& mConfig.value().getMemoryType() != MemoryType::kUNKNOWN,
|
||||
"Expected lora weights to be in CPU memory");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32");
|
||||
|
||||
TLLM_CHECK_WITH_INFO(mConfig.value().getShape()[0] == mWeights.value().getShape()[0],
|
||||
"Expected dim 0 of lora weights and lora config to have the same size");
|
||||
|
||||
@ -23,6 +23,7 @@
|
||||
#include "tensorrt_llm/executor/serializeUtils.h"
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
@ -1162,10 +1163,11 @@ KvCacheConfig Serialization::deserializeKvCacheConfig(std::istream& is)
|
||||
auto secondaryOffloadMinPriority = su::deserialize<std::optional<executor::RetentionPriority>>(is);
|
||||
auto eventBufferMaxSize = su::deserialize<size_t>(is);
|
||||
auto useUvm = su::deserialize<bool>(is);
|
||||
auto attentionDpEventsGatherPeriodMs = su::deserialize<SizeType32>(is);
|
||||
|
||||
return KvCacheConfig{enableBlockReuse, maxTokens, maxAttentionWindowVec, sinkTokenLength, freeGpuMemoryFraction,
|
||||
hostCacheSize, onboardBlocks, crossKvCacheFraction, secondaryOffloadMinPriority, eventBufferMaxSize,
|
||||
enablePartialReuse, copyOnPartialReuse, useUvm};
|
||||
enablePartialReuse, copyOnPartialReuse, useUvm, attentionDpEventsGatherPeriodMs};
|
||||
}
|
||||
|
||||
void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os)
|
||||
@ -1183,6 +1185,7 @@ void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream&
|
||||
su::serialize(kvCacheConfig.getSecondaryOffloadMinPriority(), os);
|
||||
su::serialize(kvCacheConfig.getEventBufferMaxSize(), os);
|
||||
su::serialize(kvCacheConfig.getUseUvm(), os);
|
||||
su::serialize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs(), os);
|
||||
}
|
||||
|
||||
size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig)
|
||||
@ -1202,6 +1205,7 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig)
|
||||
totalSize += su::serializedSize(kvCacheConfig.getSecondaryOffloadMinPriority());
|
||||
totalSize += su::serializedSize(kvCacheConfig.getEventBufferMaxSize());
|
||||
totalSize += su::serializedSize(kvCacheConfig.getUseUvm());
|
||||
totalSize += su::serializedSize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs());
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
@ -2181,6 +2185,237 @@ std::vector<RequestStatsPerIteration> Serialization::deserializeRequestStatsPerI
|
||||
return iterRequestStatsVec;
|
||||
}
|
||||
|
||||
// KVCacheEvents deque
|
||||
std::vector<char> Serialization::serialize(std::deque<KVCacheEvent> const& eventQueue)
|
||||
{
|
||||
// Compute the size of serialized buffer
|
||||
size_t totalSize = 0;
|
||||
totalSize += sizeof(size_t);
|
||||
for (auto const& event : eventQueue)
|
||||
{
|
||||
totalSize += su::serializedSize(event);
|
||||
}
|
||||
|
||||
std::vector<char> buffer(totalSize);
|
||||
std::stringbuf strbuf(std::ios_base::out | std::ios_base::in);
|
||||
strbuf.pubsetbuf(buffer.data(), buffer.size());
|
||||
std::ostream os(&strbuf);
|
||||
|
||||
su::serialize(eventQueue.size(), os);
|
||||
for (auto const& event : eventQueue)
|
||||
{
|
||||
su::serialize(event, os);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
std::deque<KVCacheEvent> Serialization::deserializeKVCacheEvents(std::vector<char>& buffer)
|
||||
{
|
||||
std::deque<KVCacheEvent> kvCacheEvents;
|
||||
su::VectorWrapBuf<char> strbuf(buffer);
|
||||
std::istream is(&strbuf);
|
||||
auto numEvents = su::deserialize<std::size_t>(is);
|
||||
for (std::size_t event = 0; event < numEvents; ++event)
|
||||
{
|
||||
kvCacheEvents.emplace_back(Serialization::deserializeKVCacheEvent(is));
|
||||
}
|
||||
return kvCacheEvents;
|
||||
}
|
||||
|
||||
// KVCacheEvent
|
||||
size_t Serialization::serializedSize(KVCacheEvent const& event)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(event.eventId);
|
||||
totalSize += su::serializedSize(event.data);
|
||||
totalSize += su::serializedSize(event.windowSize);
|
||||
totalSize += su::serializedSize(event.attentionDpRank);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void Serialization::serialize(KVCacheEvent const& event, std::ostream& os)
|
||||
{
|
||||
su::serialize(event.eventId, os);
|
||||
su::serialize(event.data, os);
|
||||
su::serialize(event.windowSize, os);
|
||||
su::serialize(event.attentionDpRank, os);
|
||||
}
|
||||
|
||||
KVCacheEvent Serialization::deserializeKVCacheEvent(std::istream& is)
|
||||
{
|
||||
auto eventId = su::deserialize<IdType>(is);
|
||||
auto data = su::deserialize<KVCacheEventData>(is);
|
||||
auto windowSize = su::deserialize<SizeType32>(is);
|
||||
auto attentionDpRank = su::deserialize<std::optional<SizeType32>>(is);
|
||||
|
||||
return KVCacheEvent{eventId, data, windowSize, attentionDpRank};
|
||||
}
|
||||
|
||||
// KVCacheCreatedData
|
||||
size_t Serialization::serializedSize(KVCacheCreatedData const& data)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(data.numBlocksPerCacheLevel);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void Serialization::serialize(KVCacheCreatedData const& data, std::ostream& os)
|
||||
{
|
||||
su::serialize(data.numBlocksPerCacheLevel, os);
|
||||
}
|
||||
|
||||
KVCacheCreatedData Serialization::deserializeKVCacheCreatedData(std::istream& is)
|
||||
{
|
||||
auto numBlocksPerCacheLevel = su::deserialize<std::vector<SizeType32>>(is);
|
||||
return KVCacheCreatedData{numBlocksPerCacheLevel};
|
||||
}
|
||||
|
||||
// KVCacheStoredData
|
||||
size_t Serialization::serializedSize(KVCacheStoredData const& data)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(data.parentHash);
|
||||
totalSize += su::serializedSize(data.blocks);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void Serialization::serialize(KVCacheStoredData const& data, std::ostream& os)
|
||||
{
|
||||
su::serialize(data.parentHash, os);
|
||||
su::serialize(data.blocks, os);
|
||||
}
|
||||
|
||||
KVCacheStoredData Serialization::deserializeKVCacheStoredData(std::istream& is)
|
||||
{
|
||||
auto parentHash = su::deserialize<std::optional<IdType>>(is);
|
||||
auto blocks = su::deserialize<std::vector<KVCacheStoredBlockData>>(is);
|
||||
return KVCacheStoredData{parentHash, blocks};
|
||||
}
|
||||
|
||||
// KVCacheStoredBlockData
|
||||
size_t Serialization::serializedSize(KVCacheStoredBlockData const& data)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(data.blockHash);
|
||||
totalSize += su::serializedSize(data.tokens);
|
||||
totalSize += su::serializedSize(data.loraId);
|
||||
totalSize += su::serializedSize(data.cacheLevel);
|
||||
totalSize += su::serializedSize(data.priority);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void Serialization::serialize(KVCacheStoredBlockData const& data, std::ostream& os)
|
||||
{
|
||||
su::serialize(data.blockHash, os);
|
||||
su::serialize(data.tokens, os);
|
||||
su::serialize(data.loraId, os);
|
||||
su::serialize(data.cacheLevel, os);
|
||||
su::serialize(data.priority, os);
|
||||
}
|
||||
|
||||
KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::istream& is)
|
||||
{
|
||||
auto blockHash = su::deserialize<IdType>(is);
|
||||
auto tokens = su::deserialize<tensorrt_llm::runtime::VecUniqueTokens>(is);
|
||||
auto loraId = su::deserialize<std::optional<tensorrt_llm::runtime::LoraTaskIdType>>(is);
|
||||
auto cacheLevel = su::deserialize<SizeType32>(is);
|
||||
auto priority = su::deserialize<SizeType32>(is);
|
||||
|
||||
return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority};
|
||||
}
|
||||
|
||||
// KVcacheRemovedData
|
||||
|
||||
size_t Serialization::serializedSize(KVCacheRemovedData const& data)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(data.blockHashes);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void Serialization::serialize(KVCacheRemovedData const& data, std::ostream& os)
|
||||
{
|
||||
su::serialize(data.blockHashes, os);
|
||||
}
|
||||
|
||||
KVCacheRemovedData Serialization::deserializeKVCacheRemovedData(std::istream& is)
|
||||
{
|
||||
auto blockHashes = su::deserialize<std::vector<IdType>>(is);
|
||||
return KVCacheRemovedData{blockHashes};
|
||||
}
|
||||
|
||||
// KVCacheEventDiff
|
||||
template <typename T>
|
||||
size_t Serialization::serializedSize(KVCacheEventDiff<T> const& data)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(data.oldValue);
|
||||
totalSize += su::serializedSize(data.newValue);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Serialization::serialize(KVCacheEventDiff<T> const& data, std::ostream& os)
|
||||
{
|
||||
su::serialize(data.oldValue, os);
|
||||
su::serialize(data.newValue, os);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
KVCacheEventDiff<T> Serialization::deserializeKVCacheEventDiff(std::istream& is)
|
||||
{
|
||||
auto oldValue = su::deserialize<T>(is);
|
||||
auto newValue = su::deserialize<T>(is);
|
||||
return KVCacheEventDiff<T>{oldValue, newValue};
|
||||
}
|
||||
|
||||
// KVCacheUpdatedData
|
||||
size_t Serialization::serializedSize(KVCacheUpdatedData const& data)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(data.blockHash);
|
||||
totalSize += su::serializedSize(data.cacheLevel);
|
||||
totalSize += su::serializedSize(data.priority);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void Serialization::serialize(KVCacheUpdatedData const& data, std::ostream& os)
|
||||
{
|
||||
su::serialize(data.blockHash, os);
|
||||
su::serialize(data.cacheLevel, os);
|
||||
su::serialize(data.priority, os);
|
||||
}
|
||||
|
||||
KVCacheUpdatedData Serialization::deserializeKVCacheUpdatedData(std::istream& is)
|
||||
{
|
||||
auto blockHash = su::deserialize<IdType>(is);
|
||||
auto cacheLevel = su::deserialize<std::optional<KVCacheEventDiff<SizeType32>>>(is);
|
||||
auto priority = su::deserialize<std::optional<KVCacheEventDiff<SizeType32>>>(is);
|
||||
return KVCacheUpdatedData{blockHash, cacheLevel, priority};
|
||||
}
|
||||
|
||||
// UniqueToken
|
||||
size_t Serialization::serializedSize(tensorrt_llm::runtime::UniqueToken const& token)
|
||||
{
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(token.tokenId);
|
||||
totalSize += su::serializedSize(token.tokenExtraId);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void Serialization::serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os)
|
||||
{
|
||||
su::serialize(token.tokenId, os);
|
||||
su::serialize(token.tokenExtraId, os);
|
||||
}
|
||||
|
||||
tensorrt_llm::runtime::UniqueToken Serialization::deserializeUniqueToken(std::istream& is)
|
||||
{
|
||||
auto tokenId = su::deserialize<tensorrt_llm::runtime::TokenIdType>(is);
|
||||
auto tokenExtraId = su::deserialize<tensorrt_llm::runtime::TokenExtraIdType>(is);
|
||||
return tensorrt_llm::runtime::UniqueToken{tokenId, tokenExtraId};
|
||||
}
|
||||
|
||||
// String
|
||||
std::string Serialization::deserializeString(std::istream& is)
|
||||
{
|
||||
|
||||
@ -122,6 +122,14 @@ static_assert(hasSerializedSize<GuidedDecodingParams>(size_t()));
|
||||
static_assert(!hasSerializedSize<std::string>(size_t()));
|
||||
static_assert(!hasSerializedSize<std::optional<float>>(size_t()));
|
||||
static_assert(hasSerializedSize<CacheTransceiverConfig>(size_t()));
|
||||
static_assert(hasSerializedSize<KVCacheEvent>(size_t()));
|
||||
static_assert(hasSerializedSize<KVCacheCreatedData>(size_t()));
|
||||
static_assert(hasSerializedSize<KVCacheStoredData>(size_t()));
|
||||
static_assert(hasSerializedSize<KVCacheStoredBlockData>(size_t()));
|
||||
static_assert(hasSerializedSize<KVCacheRemovedData>(size_t()));
|
||||
static_assert(hasSerializedSize<KVCacheEventDiff<SizeType32>>(size_t()));
|
||||
static_assert(hasSerializedSize<KVCacheUpdatedData>(size_t()));
|
||||
static_assert(hasSerializedSize<tensorrt_llm::runtime::UniqueToken>(size_t()));
|
||||
|
||||
template <typename T>
|
||||
size_t serializedSize(T const& data)
|
||||
@ -219,6 +227,14 @@ static_assert(hasSerialize<ContextPhaseParams>(nullptr));
|
||||
static_assert(!hasSerialize<std::string>(nullptr));
|
||||
static_assert(!hasSerialize<std::optional<float>>(nullptr));
|
||||
static_assert(hasSerialize<CacheTransceiverConfig>(nullptr));
|
||||
static_assert(hasSerialize<KVCacheEvent>(nullptr));
|
||||
static_assert(hasSerialize<KVCacheCreatedData>(nullptr));
|
||||
static_assert(hasSerialize<KVCacheStoredData>(nullptr));
|
||||
static_assert(hasSerialize<KVCacheStoredBlockData>(nullptr));
|
||||
static_assert(hasSerialize<KVCacheRemovedData>(nullptr));
|
||||
static_assert(hasSerialize<KVCacheEventDiff<SizeType32>>(nullptr));
|
||||
static_assert(hasSerialize<KVCacheUpdatedData>(nullptr));
|
||||
static_assert(hasSerialize<tensorrt_llm::runtime::UniqueToken>(nullptr));
|
||||
|
||||
template <typename T>
|
||||
void serialize(T const& data, std::ostream& os)
|
||||
@ -291,6 +307,22 @@ struct get_variant_alternative_type
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
T deserialize(std::istream& is);
|
||||
|
||||
// Helper function to deserialize variant by index using template recursion
|
||||
template <typename T, std::size_t... Is>
|
||||
T deserializeVariantByIndex(std::istream& is, std::size_t index, std::index_sequence<Is...> /*indices*/)
|
||||
{
|
||||
T result;
|
||||
bool found = ((Is == index ? (result = deserialize<std::variant_alternative_t<Is, T>>(is), true) : false) || ...);
|
||||
if (!found)
|
||||
{
|
||||
TLLM_THROW("Invalid variant index during deserialization: " + std::to_string(index));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Deserialize
|
||||
template <typename T>
|
||||
T deserialize(std::istream& is)
|
||||
@ -511,6 +543,38 @@ T deserialize(std::istream& is)
|
||||
{
|
||||
return Serialization::deserializeCacheTransceiverConfig(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheEvent>)
|
||||
{
|
||||
return Serialization::deserializeKVCacheEvent(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheCreatedData>)
|
||||
{
|
||||
return Serialization::deserializeKVCacheCreatedData(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheStoredData>)
|
||||
{
|
||||
return Serialization::deserializeKVCacheStoredData(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheStoredBlockData>)
|
||||
{
|
||||
return Serialization::deserializeKVCacheStoredBlockData(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheRemovedData>)
|
||||
{
|
||||
return Serialization::deserializeKVCacheRemovedData(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheEventDiff<SizeType32>>)
|
||||
{
|
||||
return Serialization::deserializeKVCacheEventDiff<SizeType32>(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheUpdatedData>)
|
||||
{
|
||||
return Serialization::deserializeKVCacheUpdatedData(is);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, tensorrt_llm::runtime::UniqueToken>)
|
||||
{
|
||||
return Serialization::deserializeUniqueToken(is);
|
||||
}
|
||||
// Optional
|
||||
else if constexpr (std::is_same_v<T, std::optional<typename ValueType<T>::type>>)
|
||||
{
|
||||
@ -547,23 +611,7 @@ T deserialize(std::istream& is)
|
||||
std::size_t index = 0;
|
||||
is.read(reinterpret_cast<char*>(&index), sizeof(index));
|
||||
|
||||
// TODO: Is there a better way to implement this?
|
||||
T data;
|
||||
if (index == 0)
|
||||
{
|
||||
using U = std::variant_alternative_t<0, T>;
|
||||
data = deserialize<U>(is);
|
||||
}
|
||||
else if (index == 1)
|
||||
{
|
||||
using U = std::variant_alternative_t<1, T>;
|
||||
data = deserialize<U>(is);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Serialization of variant of size > 2 is not supported.");
|
||||
}
|
||||
return data;
|
||||
return deserializeVariantByIndex<T>(is, index, std::make_index_sequence<std::variant_size_v<T>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@ -256,9 +256,9 @@ public:
|
||||
constexpr int SF_VEC_SIZE = 16;
|
||||
using PackedVec = PackedVec<DType>;
|
||||
PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&val);
|
||||
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt, token_id,
|
||||
m_access_id_in_token, std::nullopt, m_params.hidden_dim,
|
||||
reinterpret_cast<uint32_t*>(m_params.scale_out), m_params.layout);
|
||||
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt, token_id, m_access_id_in_token,
|
||||
std::nullopt, m_params.hidden_dim / SF_VEC_SIZE, reinterpret_cast<uint32_t*>(m_params.scale_out),
|
||||
m_params.layout);
|
||||
reinterpret_cast<uint32_t*>(m_params.quant_out)[m_access_id]
|
||||
= cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, m_scale_factor, sf_out);
|
||||
}
|
||||
|
||||
@ -132,7 +132,7 @@ struct AllReduceFusionParams
|
||||
float rms_eps;
|
||||
float* scale_factor;
|
||||
bool use_oneshot;
|
||||
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED;
|
||||
QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED;
|
||||
cudaStream_t stream;
|
||||
AllReduceFusionPattern pattern;
|
||||
bool trigger_completion_at_end = true;
|
||||
|
||||
@ -99,15 +99,15 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
|
||||
uint32_t* offset_access_ptr;
|
||||
uint32_t* buffer_flags;
|
||||
|
||||
__device__ explicit LamportFlags(uint32_t* buffer_flags)
|
||||
__device__ explicit LamportFlags(uint32_t* buffer_flags, uint32_t buffer_size)
|
||||
: offset_access_ptr(&buffer_flags[4])
|
||||
, buffer_flags(buffer_flags)
|
||||
, buffer_size(buffer_size)
|
||||
{
|
||||
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
|
||||
buffer_size = flag.z;
|
||||
input_offset = flag.x * (buffer_size << 1U);
|
||||
clear_offset = flag.y * (buffer_size << 1U);
|
||||
num_tokens_prev = flag.w;
|
||||
num_tokens_prev = flag.z;
|
||||
}
|
||||
|
||||
__device__ void cta_arrive()
|
||||
@ -135,7 +135,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
|
||||
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
|
||||
buffer_flags[0] = (flag.x + 1) % 3;
|
||||
buffer_flags[1] = (flag.y + 1) % 3;
|
||||
buffer_flags[3] = num_tokens;
|
||||
buffer_flags[2] = num_tokens;
|
||||
*(offset_access_ptr) = 0;
|
||||
}
|
||||
}
|
||||
@ -144,7 +144,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags
|
||||
|
||||
template <int WORLD_SIZE, typename T>
|
||||
__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
|
||||
int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
|
||||
int buffer_M, int token_dim, int rank, uint32_t buffer_size, uint32_t* buffer_flags, bool wait_for_results)
|
||||
{
|
||||
int elt = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
if (elt >= token_dim)
|
||||
@ -155,7 +155,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
LamportFlags flags(buffer_flags);
|
||||
LamportFlags flags(buffer_flags, buffer_size);
|
||||
|
||||
// Capture the number of tokens in previous iteration so that we can properly clear the buffer
|
||||
// The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
|
||||
@ -217,15 +217,17 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
|
||||
// Similarly clear broadcast buffer here
|
||||
for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
|
||||
if (elt < token_dim)
|
||||
{
|
||||
uint32_t clr_token_idx = token + clr_tok * gridDim.x;
|
||||
if (clr_token_idx < buffer_M)
|
||||
// Similarly clear broadcast buffer here
|
||||
for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
|
||||
{
|
||||
input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
|
||||
= fromFloat<T>(-0.f);
|
||||
uint32_t clr_token_idx = token + clr_tok * gridDim.x;
|
||||
if (clr_token_idx < buffer_M)
|
||||
{
|
||||
input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
|
||||
= fromFloat<T>(-0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -240,20 +242,24 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
|
||||
// blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
|
||||
if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
|
||||
{
|
||||
uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
|
||||
uint64_t elt_load_offset = blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
|
||||
if (elt_load_offset < token_dim)
|
||||
{
|
||||
uint64_t current_pos = blockIdx.x * token_dim + elt_load_offset;
|
||||
|
||||
void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
|
||||
// We have 2 assumptions here:
|
||||
// 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
|
||||
// 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
|
||||
float2 val = loadfloat2(lamport_ptr);
|
||||
while (isNegZero(*(T*) &val))
|
||||
{
|
||||
val = loadfloat2(lamport_ptr);
|
||||
}
|
||||
if (output_ptr)
|
||||
{
|
||||
*((float2*) &output_ptr[current_pos]) = val;
|
||||
void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
|
||||
// We have 2 assumptions here:
|
||||
// 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
|
||||
// 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
|
||||
float2 val = loadfloat2(lamport_ptr);
|
||||
while (isNegZero(*(T*) &val))
|
||||
{
|
||||
val = loadfloat2(lamport_ptr);
|
||||
}
|
||||
if (output_ptr)
|
||||
{
|
||||
*((float2*) &output_ptr[current_pos]) = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -263,10 +269,11 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
|
||||
}
|
||||
|
||||
#define LAUNCH_ALL_REDUCE_KERNEL(WORLD_SIZE, T) \
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel<WORLD_SIZE, T>, \
|
||||
reinterpret_cast<T*>(params.output), reinterpret_cast<T*>(params.input), \
|
||||
reinterpret_cast<T**>(params.buffer_ptrs_dev), (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, \
|
||||
params.token_dim, params.rank, reinterpret_cast<uint32_t*>(params.buffer_flags), params.wait_for_results));
|
||||
TLLM_CUDA_CHECK( \
|
||||
cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel<WORLD_SIZE, T>, reinterpret_cast<T*>(params.output), \
|
||||
reinterpret_cast<T*>(params.input), reinterpret_cast<T**>(params.buffer_ptrs_dev), \
|
||||
(T*) params.multicast_ptr, params.num_tokens, params.buffer_M, params.token_dim, params.rank, \
|
||||
params.buffer_size, reinterpret_cast<uint32_t*>(params.buffer_flags), params.wait_for_results));
|
||||
|
||||
void twoshot_allreduce_op(AllReduceParams const& params)
|
||||
{
|
||||
@ -369,20 +376,33 @@ inline __device__ T add(T a, T b)
|
||||
}
|
||||
|
||||
#define FINAL_MASK 0xffffffff
|
||||
#define WARP_SIZE 32
|
||||
|
||||
template <typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val)
|
||||
{
|
||||
// Get the actual number of active threads in this warp
|
||||
int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1)));
|
||||
unsigned int mask = (1U << active_warp_size) - 1;
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
{
|
||||
if (offset < active_warp_size)
|
||||
{
|
||||
val = add<T>(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE));
|
||||
}
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
inline __device__ float block_reduce_sum(float val)
|
||||
{
|
||||
__shared__ float smem[32];
|
||||
int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32;
|
||||
__shared__ float smem[WARP_SIZE];
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps
|
||||
|
||||
val = warpReduceSum(val);
|
||||
if (lane_id == 0)
|
||||
{
|
||||
@ -391,6 +411,7 @@ inline __device__ float block_reduce_sum(float val)
|
||||
__syncthreads();
|
||||
val = lane_id < warp_num ? smem[lane_id] : 0.f;
|
||||
val = warpReduceSum(val);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
@ -410,7 +431,7 @@ __device__ float4 loadfloat4(void const* ptr)
|
||||
template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
|
||||
__global__ void __launch_bounds__(128, 1)
|
||||
RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon,
|
||||
T_IN const* residual, int batch_size, uint32_t* buffer_flags)
|
||||
T_IN const* residual, int batch_size, uint32_t buffer_size, uint32_t* buffer_flags)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
static bool const LAMPORT = true;
|
||||
@ -433,7 +454,7 @@ __global__ void __launch_bounds__(128, 1)
|
||||
|
||||
int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
|
||||
|
||||
LamportFlags flags(buffer_flags);
|
||||
LamportFlags flags(buffer_flags, buffer_size);
|
||||
T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];
|
||||
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
@ -598,16 +619,15 @@ __global__ void __launch_bounds__(128, 1)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int H_DIM>
|
||||
template <typename T, int H_DIM, int NUM_THREADS>
|
||||
void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T const* gamma, double epsilon,
|
||||
T const* residual, uint32_t* buffer_flags, int batch, cudaStream_t stream)
|
||||
T const* residual, uint32_t buffer_size, uint32_t* buffer_flags, int batch, cudaStream_t stream)
|
||||
{
|
||||
|
||||
// input to rmsnorm is the buffer in the twoshot ar
|
||||
// We should use prenorm output to determine the actual used size
|
||||
float _epsilon{static_cast<float>(epsilon)};
|
||||
|
||||
static constexpr int NUM_THREADS = 128;
|
||||
static constexpr int CGA_THREADS = NUM_THREADS;
|
||||
constexpr int iters = H_DIM / CGA_THREADS;
|
||||
|
||||
@ -628,28 +648,34 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons
|
||||
&RMSNorm<H_DIM, NUM_THREADS, 1, T, T>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
|
||||
config.dynamicSmemBytes = shmem_size;
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &RMSNorm<H_DIM, NUM_THREADS, 1, T, T>, prenorm_output, normed_output,
|
||||
input, gamma, _epsilon, residual, batch, buffer_flags));
|
||||
input, gamma, _epsilon, residual, batch, buffer_size, buffer_flags));
|
||||
}
|
||||
|
||||
#define LAUNCH_RMSNORM_KERNEL(T, H_DIM) \
|
||||
twoshot_rmsnorm<T, H_DIM>(static_cast<T*>(params.residual_output), static_cast<T*>(params.output), \
|
||||
#define LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS) \
|
||||
twoshot_rmsnorm<T, H_DIM, NUM_THREADS>(static_cast<T*>(params.residual_output), static_cast<T*>(params.output), \
|
||||
static_cast<T const*>(params.input), static_cast<T const*>(params.gamma), params.epsilon, \
|
||||
static_cast<T const*>(params.residual), params.buffer_flags, params.batch, params.stream)
|
||||
static_cast<T const*>(params.residual), params.buffer_size, params.buffer_flags, params.batch, params.stream)
|
||||
|
||||
void twoshot_rmsnorm_op(RMSNormParams const& params)
|
||||
{
|
||||
auto dtype = params.dtype;
|
||||
|
||||
#define CASE_DISPATCH_RMSNORM(T, H_DIM, NUM_THREADS) \
|
||||
case H_DIM: LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS); break;
|
||||
|
||||
#define TYPE_DISPATCH_RMSNORM(T) \
|
||||
CASE_DISPATCH_RMSNORM(T, 2048, 128) \
|
||||
CASE_DISPATCH_RMSNORM(T, 2880, 120) \
|
||||
CASE_DISPATCH_RMSNORM(T, 4096, 128) \
|
||||
CASE_DISPATCH_RMSNORM(T, 5120, 128) \
|
||||
CASE_DISPATCH_RMSNORM(T, 7168, 128) \
|
||||
CASE_DISPATCH_RMSNORM(T, 8192, 128)
|
||||
|
||||
if (dtype == nvinfer1::DataType::kFLOAT)
|
||||
{
|
||||
switch (params.hidden_dim)
|
||||
{
|
||||
case 2048: LAUNCH_RMSNORM_KERNEL(float, 2048); break;
|
||||
case 4096: LAUNCH_RMSNORM_KERNEL(float, 4096); break;
|
||||
// Llama-4 Hidden Dimension
|
||||
case 5120: LAUNCH_RMSNORM_KERNEL(float, 5120); break;
|
||||
// DeepSeek Hidden Dimension
|
||||
case 7168: LAUNCH_RMSNORM_KERNEL(float, 7168); break;
|
||||
case 8192: LAUNCH_RMSNORM_KERNEL(float, 8192); break;
|
||||
TYPE_DISPATCH_RMSNORM(float);
|
||||
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
|
||||
}
|
||||
}
|
||||
@ -657,13 +683,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
|
||||
{
|
||||
switch (params.hidden_dim)
|
||||
{
|
||||
case 2048: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 2048); break;
|
||||
case 4096: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 4096); break;
|
||||
// Llama-4 Hidden Dimension
|
||||
case 5120: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 5120); break;
|
||||
// DeepSeek Hidden Dimension
|
||||
case 7168: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 7168); break;
|
||||
case 8192: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 8192); break;
|
||||
TYPE_DISPATCH_RMSNORM(__nv_bfloat16);
|
||||
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
|
||||
}
|
||||
}
|
||||
@ -671,13 +691,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
|
||||
{
|
||||
switch (params.hidden_dim)
|
||||
{
|
||||
case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
|
||||
case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
|
||||
// Llama-4 Hidden Dimension
|
||||
case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
|
||||
// DeepSeek Hidden Dimension
|
||||
case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
|
||||
case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
|
||||
TYPE_DISPATCH_RMSNORM(__nv_half);
|
||||
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
|
||||
}
|
||||
}
|
||||
@ -685,6 +699,8 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");
|
||||
}
|
||||
#undef TYPE_DISPATCH_RMSNORM
|
||||
#undef CASE_DISPATCH_RMSNORM
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::kernels::mnnvl
|
||||
|
||||
@ -30,6 +30,7 @@ struct AllReduceParams
|
||||
int buffer_M;
|
||||
int num_tokens;
|
||||
int token_dim;
|
||||
uint32_t buffer_size;
|
||||
void** buffer_ptrs_dev;
|
||||
void* multicast_ptr;
|
||||
void* buffer_flags;
|
||||
@ -50,6 +51,7 @@ struct RMSNormParams
|
||||
void const* gamma;
|
||||
double epsilon;
|
||||
void* residual;
|
||||
uint32_t buffer_size;
|
||||
uint32_t* buffer_flags;
|
||||
int batch;
|
||||
int hidden_dim;
|
||||
|
||||
@ -150,8 +150,8 @@ __device__ __forceinline__ void fused_op(
|
||||
constexpr int SF_VEC_SIZE = 16;
|
||||
using PackedVec = PackedVec<DType>;
|
||||
PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&norm_val);
|
||||
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt /* batchIdx */,
|
||||
token_id, access_id_in_token, std::nullopt /* numRows */, params.hidden_dim,
|
||||
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt /* batchIdx */, token_id,
|
||||
access_id_in_token, std::nullopt /* numRows */, params.hidden_dim / SF_VEC_SIZE,
|
||||
reinterpret_cast<uint32_t*>(params.scale_out), params.layout);
|
||||
reinterpret_cast<uint32_t*>(params.quant_out)[access_id]
|
||||
= cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, *params.scale_factor, sf_out);
|
||||
|
||||
@ -55,7 +55,7 @@ struct AllReduceFusionParams
|
||||
void* rms_gamma;
|
||||
float rms_eps;
|
||||
float* scale_factor;
|
||||
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED;
|
||||
QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED;
|
||||
cudaStream_t stream;
|
||||
};
|
||||
|
||||
|
||||
@ -216,7 +216,8 @@ extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90(
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_64_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_32_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_40_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_48_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
@ -1820,6 +1821,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_sliding_or_chunked_causal_tma_ws_sm90_kernel", 196864, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_128_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, true, false, false, true, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, true, false, false, true, false, nullptr},
|
||||
@ -1833,6 +1835,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_kernel", 164096, 384, 64, 0, 1, false, true, true, true, false, false, false, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90_kernel", 73984, 384, 64, 0, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90},
|
||||
@ -1873,11 +1876,13 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_sliding_or_chunked_causal_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 1, 2, false, true, true, true, false, false, true, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 2, 2, false, true, true, true, false, false, true, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 1, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90_kernel", 73984, 384, 64, 0, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_custom_mask_softmax_tma_ws_sm90_kernel", 73984, 384, 64, 3, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90},
|
||||
@ -1887,7 +1892,10 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_custom_mask_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 0, 1, false, true, true, true, false, false, false, true, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_custom_mask_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, true, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 0, 2, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_32_alibi_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 40, 40, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_40_alibi_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 48, 48, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_48_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_48_alibi_tma_ws_sm90},
|
||||
@ -2522,7 +2530,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_causal_sm89_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sliding_or_chunked_causal_sm89_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_custom_mask_sm89_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_kernel_nl_tiled", 32768, 128, 64, 0, 0, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_kernel_nl_tiled", 32768, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 1, 0, false, true, false, true, true, false, false, true, nullptr},
|
||||
@ -2531,7 +2541,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sliding_or_chunked_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_custom_mask_output_bf16_sm89_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 0, 0, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 64, 32, 32, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr},
|
||||
@ -3144,7 +3156,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm89_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm89_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm89_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm89_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_nl_tiled},
|
||||
#endif
|
||||
@ -3756,7 +3770,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm80_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm80_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm80_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm80_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_nl_tiled},
|
||||
#endif
|
||||
@ -4368,13 +4384,17 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm86_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm86_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm86_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm86_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_nl_tiled},
|
||||
#endif
|
||||
|
||||
#ifndef EXCLUDE_SM_100
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_nl_tiled},
|
||||
#endif
|
||||
@ -4784,7 +4804,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm120_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_softcapping_sm120_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_causal_softcapping_sm120_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm120_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 128, 128, 32, 32, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_kernel_nl", 12288, 128, 128, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_nl},
|
||||
@ -4874,7 +4896,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_causal_sm120_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sliding_or_chunked_causal_sm120_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_custom_mask_sm120_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm120_nl},
|
||||
@ -4883,7 +4907,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sliding_or_chunked_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_custom_mask_output_bf16_sm120_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 128, 128, 16, 16, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_causal_sm120_kernel_nl_tiled", 16384, 128, 128, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sm120_nl_tiled},
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d5bb139b12206a563daec9fa473dda422319bde5ae5f965d37cf5ca67d325c49
|
||||
size 1005546
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user