diff --git a/.coderabbit.yaml b/.coderabbit.yaml index a7d05d70d5..1644fad272 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -20,7 +20,7 @@ language: "en-US" reviews: profile: chill auto_title_placeholder: '@coderabbitai title' - auto_title_instructions: 'Should follow the format: "[fix/feat/doc/infra/...] \". Keep it concise.' + auto_title_instructions: 'Format: "[] ". Category must be one of: fix, feat, doc, infra, style, refactor, perf, test, chore, revert. Enclose the category in square brackets. Title should be concise (<= 60 chars). Example: "[feat] Add logit_bias support".' commit_status: false collapse_walkthrough: true assess_linked_issues: true diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..733d2aced5 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,24 @@ +# Auto basic formatting when saving file with EditorConfig https://editorconfig.org/ + +# top-most EditorConfig file +root = true + +[*] +end_of_line = lf +trim_trailing_whitespace = true +insert_final_newline = true + +# make +[Makefile*] +indent_style = tab +indent_size = 4 + +# c++ +[*.{cpp,cu,h}] +indent_style = space +indent_size = 4 + +# python +[*.py] +indent_style = space +indent_size = 4 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6fc19aa2f1..afcf5adcda 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,13 +6,39 @@ # Without approval from a member of this team, PRs cannot be merged to release branches. # * @NVIDIA/trt-llm-release-branch-approval +## TensorRT-LLM Infra +### CI +/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs +### Setup +/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs +### Github workflows +/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs +/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs + +## TensorRT-LLM - Docs +/docs @NVIDIA/trt-llm-doc-owners + +## Examples +/examples @NVIDIA/trt-llm-doc-owners + +## TensorRT-LLM - Triton backend +/triton_backend @NVIDIA/trt-llm-triton-backend-devs + # TensorRT-LLM Pytorch backend /tensorrt_llm/_torch @NVIDIA/trt-llm-torch-devs + +## TensorRT-LLM Pytorch - Modules +/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules + +## TensorRT-LLM Pytorch Models +/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs +/examples/models @NVIDIA/trt-llm-torch-models-devs @NVIDIA/trt-llm-doc-owners + ## TensorRT-LLM Pytorch backend - runtime /tensorrt_llm/_torch/pyexecutor @NVIDIA/trt-llm-torch-runtime-devs ## TensorRT-LLM Pytorch backend - AutoDeploy flow /tensorrt_llm/_torch/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs -/tensorrt_llm/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs +/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs @NVIDIA/trt-llm-doc-owners ## TensorRT-LLM Pytorch - Speculative Decoding /tensorrt_llm/_torch/speculative @NVIDIA/trt-llm-torch-spec-decoding @@ -31,12 +57,6 @@ /tensorrt_llm/_torch/attention_backend @NVIDIA/trt-llm-torch-attention-devs /tensorrt_llm/_torch/modules/attention.py @NVIDIA/trt-llm-torch-attention-devs -## TensorRT-LLM Pytorch - Modules -/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules - - -## TensorRT-LLM Pytorch Models -/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs ### TensorRT-LLM Pytorch - Models - Gemma /tensorrt_llm/_torch/models/modeling_gemma3.py @NVIDIA/trt-llm-torch-models-gemma-devs @NVIDIA/trt-llm-torch-models-devs @@ -87,7 +107,7 @@ /tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs /tensorrt_llm/_torch/models/modeling_nemotron_h.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs /tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs -/tensorrt_llm/_torch/pyexecutor/resource_manager.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-runtime-devs @NVIDIA/trt-llm-torch-models-devs +/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs /tensorrt_llm/_torch/modules/mamba @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs /tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs /tests/unittest/_torch/modeling/test_modeling_nemotron.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs @@ -108,8 +128,6 @@ /cpp/tensorrt_llm/runtime/loraUtils.cpp @NVIDIA/trt-llm-torch-peft /cpp/tensorrt_llm/runtime/loraUtils.h @NVIDIA/trt-llm-torch-peft -## TensorRT-LLM - Triton backend -/triton_backend @NVIDIA/trt-llm-triton-backend-devs ## TensorRT-LLM trtllm-bench Reviewers /tensorrt_llm/bench @NVIDIA/trtllm-bench-reviewers @@ -121,10 +139,9 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers /tensorrt_llm/executor @NVIDIA/trt-llm-llmapi-devs ## TensorRT-LLM LLM Disaggregated -/examples/disaggregated @NVIDIA/trt-llm-disagg-devs +/examples/disaggregated @NVIDIA/trt-llm-disagg-devs @NVIDIA/trt-llm-doc-owners /tensorrt_llm/disaggregated_params.py @NVIDIA/trt-llm-disagg-devs /tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @NVIDIA/trt-llm-disagg-devs -/tensorrt_llm/_torch/pyexecutor/py_executor.py @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheFormatter.h @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @NVIDIA/trt-llm-disagg-devs @@ -135,19 +152,6 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers /cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @NVIDIA/trt-llm-disagg-devs -## TensorRT-LLM Infra - -### CI -/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs -### Setup -/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs -### Github workflows -/tensorrt_llm/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs -/tensorrt_llm/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs - -## TensorRT-LLM - Docs -/docs @NVIDIA/trt-llm-doc-owners -/examples @NVIDIA/trt-llm-doc-owners # The rule below requires that any PR modifying public APIs must be approved by at least one member # of the NVIDIA/trt-llm-committed-api-review-committee or NVIDIA/trt-llm-noncommitted-api-review-committee team. diff --git a/.github/ISSUE_TEMPLATE/01-installation.yml b/.github/ISSUE_TEMPLATE/01-installation.yml new file mode 100644 index 0000000000..fd24fd93f0 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/01-installation.yml @@ -0,0 +1,66 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/200-installation.yml +name: 🛠️ Installation +description: Report an issue here when you hit errors during installation. +title: "[Installation]: " +labels: ["Installation"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: System Info + description: | + Please provide the following system information to help us debug your installation issue: + + ```bash + # System information + cat /etc/os-release + nvidia-smi + nvcc --version + python --version + pip list | grep -E "(tensorrt|torch|cuda)" + + # TensorRT-LLM installation method and version + pip show tensorrt_llm + ``` + value: | + **System Information:** + - OS: + - Python version: + - CUDA version: + - GPU model(s): + - Driver version: + - TensorRT version: + - PyTorch version: + - TensorRT-LLM version: + + **Detailed output:** + ```text + Paste the output of the above commands here + ``` + validations: + required: true +- type: textarea + attributes: + label: How you are installing TensorRT-LLM + description: | + Paste the full command you are trying to execute or describe your installation method. + value: | + ```sh + # Installation command or method + pip install tensorrt_llm + ``` +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [installation documentation](https://nvidia.github.io/TensorRT-LLM/installation/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/02-new-model.yml b/.github/ISSUE_TEMPLATE/02-new-model.yml new file mode 100644 index 0000000000..688c11866f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/02-new-model.yml @@ -0,0 +1,41 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/600-new-model.yml +name: 🤗 Support request for a new model from huggingface +description: Submit a proposal/request for a new model from huggingface +title: "[New Model]: " +labels: ["new model"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). + + #### We also highly recommend you read https://nvidia.github.io/TensorRT-LLM/architecture/add-model.html first to understand how to add a new model. +- type: textarea + attributes: + label: The model to consider. + description: > + A huggingface identifier, pointing to the model, e.g. `meta-llama/Llama-3.1-8B-Instruct` . + validations: + required: true +- type: textarea + attributes: + label: The closest model TensorRT-LLM already supports. + description: > + Here is the list of models already supported by TensorRT-LLM: https://github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/models (TRT backend) and https://github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/_torch/models (Pytorch backend) . Which model is the most similar to the model you want to add support for? +- type: textarea + attributes: + label: What's your difficulty of supporting the model you want? + description: > + For example, any new operators or new architecture? +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/03-documentation.yml b/.github/ISSUE_TEMPLATE/03-documentation.yml new file mode 100644 index 0000000000..df7643337b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/03-documentation.yml @@ -0,0 +1,31 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/100-documentation.yml +name: 📚 Documentation +description: Report an issue related to https://nvidia.github.io/TensorRT-LLM/ +title: "[Doc]: " +labels: ["Documentation"] +assignees: ["nv-guomingz"] + +body: +- type: textarea + attributes: + label: 📚 The doc issue + description: > + A clear and concise description of what content in https://nvidia.github.io/TensorRT-LLM/ is an issue. + validations: + required: true +- type: textarea + attributes: + label: Suggest a potential alternative/fix + description: > + Tell us how we could improve the documentation in this regard. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/04-questions.yml b/.github/ISSUE_TEMPLATE/04-questions.yml new file mode 100644 index 0000000000..75a9416e92 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/04-questions.yml @@ -0,0 +1,62 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/300-usage.yml +name: 💻 Questions +description: Raise an issue here if you don't know how to use TensorRT-LLM. +title: "[Usage]: " +labels: ["question"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: System Info + description: | + Please provide the following system information to help us debug your usage issue: + + ```bash + # System information + nvidia-smi + python --version + pip show tensorrt_llm + ``` + value: | + **System Information:** + - OS: + - Python version: + - CUDA version: + - GPU model(s): + - Driver version: + - TensorRT-LLM version: + + **Detailed output:** + ```text + Paste the output of the above commands here + ``` + validations: + required: true +- type: textarea + attributes: + label: How would you like to use TensorRT-LLM + description: | + A detailed description of how you want to use TensorRT-LLM. + value: | + I want to run inference of a [specific model](put Hugging Face link here). I don't know how to integrate it with TensorRT-LLM or optimize it for my use case. + + **Specific questions:** + - Model: + - Use case (e.g., chatbot, batch inference, real-time serving): + - Expected throughput/latency requirements: + - Multi-GPU setup needed: +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/05-feature-request.yml b/.github/ISSUE_TEMPLATE/05-feature-request.yml new file mode 100644 index 0000000000..32c1ee43c7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/05-feature-request.yml @@ -0,0 +1,40 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/500-feature-request.yml +name: 🚀 Feature request +description: Submit a proposal/request for a new TensorRT-LLM feature +title: "[Feature]: " +labels: ["feature request"] +assignees: ["laikhtewari"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: 🚀 The feature, motivation and pitch + description: > + A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. + validations: + required: true +- type: textarea + attributes: + label: Alternatives + description: > + A description of any alternative solutions or features you've considered, if any. +- type: textarea + attributes: + label: Additional context + description: > + Add any other context or screenshots about the feature request. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/06-bug-report.yml b/.github/ISSUE_TEMPLATE/06-bug-report.yml new file mode 100644 index 0000000000..c41ff62ded --- /dev/null +++ b/.github/ISSUE_TEMPLATE/06-bug-report.yml @@ -0,0 +1,191 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/400-bug-report.yml +name: "🐛 Bug Report" +description: Submit a bug report to help us improve TensorRT-LLM +title: "[Bug]: " +labels: [ "bug" ] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: markdown + attributes: + value: | + ⚠️ **SECURITY WARNING:** Please review any text you paste to ensure it does not contain sensitive information such as: + - API tokens or keys (e.g., Hugging Face tokens, OpenAI API keys) + - Passwords or authentication credentials + - Private URLs or endpoints + - Personal or confidential data + + Consider redacting or replacing sensitive values with placeholders like `<YOUR_TOKEN_HERE>` when sharing configuration or code examples. +- type: textarea + id: system-info + attributes: + label: System Info + description: Please share your system info with us. + placeholder: | + - CPU architecture (e.g., x86_64, aarch64) + - CPU/Host memory size (if known) + - GPU properties + - GPU name (e.g., NVIDIA H100, NVIDIA A100, NVIDIA L40S) + - GPU memory size (if known) + - Clock frequencies used (if applicable) + - Libraries + - TensorRT-LLM branch or tag (e.g., main, v0.7.1) + - TensorRT-LLM commit (if known) + - Versions of TensorRT, Modelopt, CUDA, cuBLAS, etc. used + - Container used (if running TensorRT-LLM in a container) + - NVIDIA driver version + - OS (Ubuntu 24.04, CentOS 8) + - Any other information that may be useful in reproducing the bug + + **Commands to gather system information:** + ```bash + nvidia-smi + nvcc --version + python --version + pip show tensorrt_llm tensorrt torch + ``` + validations: + required: true + +- type: textarea + id: who-can-help + attributes: + label: Who can help? + description: | + To expedite the response to your issue, it would be helpful if you could identify the appropriate person + to tag using the **@** symbol. Here is a general guideline on **whom to tag**. + + Rest assured that all issues are reviewed by the core maintainers. If you are unsure about whom to tag, + you can leave it blank, and a core maintainer will make sure to involve the appropriate person. + + Please tag fewer than 3 people. + + Quantization: @Tracin + + Documentation: @juney-nvidia + + Feature request: @laikhtewari + + Performance: @kaiyux + + placeholder: "@Username ..." + +- type: checkboxes + id: information-scripts-examples + attributes: + label: Information + description: 'The problem arises when using:' + options: + - label: "The official example scripts" + - label: "My own modified scripts" + +- type: checkboxes + id: information-tasks + attributes: + label: Tasks + description: "The tasks I am working on are:" + options: + - label: "An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)" + - label: "My own task or dataset (give details below)" + +- type: textarea + id: reproduction + validations: + required: true + attributes: + label: Reproduction + description: | + Please provide a clear and concise description of what the bug is and how to reproduce it. + + If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: + + ```python + from tensorrt_llm import LLM + from tensorrt_llm.sampling_params import SamplingParams + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct") + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + ``` + + If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. + + Remember to use code tags to properly format your code. You can refer to the + link https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting for guidance on code formatting. + + Please refrain from using screenshots, as they can be difficult to read and prevent others from copying and pasting your code. + It would be most helpful if we could reproduce your issue by simply copying and pasting your scripts and codes. + + Please set the environment variable `export TLLM_DEBUG_MODE=1` to turn on more logging to help debugging potential issues. + + placeholder: | + Steps to reproduce the behavior: + + 1. + 2. + 3. + + ```python + # Sample code to reproduce the problem + ``` + + ``` + The error message you got, with the full traceback and the error logs. + ``` + +- type: textarea + id: expected-behavior + validations: + required: true + attributes: + label: Expected behavior + description: "Provide a brief summary of the expected behavior of the software. Provide output files or examples if possible." + +- type: textarea + id: actual-behavior + validations: + required: true + attributes: + label: actual behavior + description: "Describe the actual behavior of the software and how it deviates from the expected behavior. Provide output files or examples if possible." + +- type: textarea + id: additional-notes + validations: + required: true + attributes: + label: additional notes + description: "Provide any additional context here you think might be useful for the TensorRT-LLM team to help debug this issue (such as experiments done, potential things to investigate)." + +- type: markdown + attributes: + value: | + ⚠️ Please separate bugs of `transformers`, `pytorch` implementation or usage from bugs of `TensorRT-LLM`. + + - If the error only appears in TensorRT-LLM, please provide the detailed script of how you run `TensorRT-LLM`, also highlight the difference and what you expect. + + Thanks for reporting 🙏! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/07-performance-discussion.yml b/.github/ISSUE_TEMPLATE/07-performance-discussion.yml new file mode 100644 index 0000000000..feb3b02501 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/07-performance-discussion.yml @@ -0,0 +1,74 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/700-performance-discussion.yml +name: ⚡ Discussion on the performance of TensorRT-LLM +description: Submit a proposal/discussion about the performance of TensorRT-LLM +title: "[Performance]: " +labels: ["Performance"] +assignees: ["byshiue", "kaiyux"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: Proposal to improve performance + description: > + How do you plan to improve TensorRT-LLM's performance? + validations: + required: false +- type: textarea + attributes: + label: Report of performance regression + description: > + Please provide detailed description of performance comparison to confirm the regression. You may want to run the benchmark script at https://github.com/NVIDIA/TensorRT-LLM/tree/main/benchmarks . + validations: + required: false +- type: textarea + attributes: + label: Misc discussion on performance + description: > + Anything about the performance. + validations: + required: false +- type: textarea + attributes: + label: Your current environment (if you think it is necessary) + description: | + Please provide the following system information to help with performance analysis: + + ```bash + # System information + nvidia-smi + nvcc --version + python --version + pip show tensorrt_llm tensorrt torch + ``` + value: | + **System Information:** + - OS: + - Python version: + - CUDA version: + - GPU model(s): + - Driver version: + - TensorRT version: + - PyTorch version: + - TensorRT-LLM version: + + **Detailed output:** + ```text + Paste the output of the above commands here + ``` + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/08-RFC.yml b/.github/ISSUE_TEMPLATE/08-RFC.yml new file mode 100644 index 0000000000..20d505171b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/08-RFC.yml @@ -0,0 +1,58 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/750-RFC.yml +name: 💬 Request for comments (RFC). +description: Ask for feedback on major architectural changes or design choices. +title: "[RFC]: " +labels: ["RFC"] +assignees: ["laikhtewari"] + +body: +- type: markdown + attributes: + value: > + #### Please take a look at previous [RFCs](https://github.com/NVIDIA/TensorRT-LLM/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference. +- type: textarea + attributes: + label: Motivation. + description: > + The motivation of the RFC. + validations: + required: true +- type: textarea + attributes: + label: Proposed Change. + description: > + The proposed change of the RFC. + validations: + required: true +- type: textarea + attributes: + label: Feedback Period. + description: > + The feedback period of the RFC. Usually at least one week. + validations: + required: false +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. + validations: + required: false +- type: textarea + attributes: + label: Any Other Things. + description: > + Any other things you would like to mention. + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! The TensorRT-LLM team reviews RFCs during regular team meetings. Most RFCs can be discussed online, but you can also reach out to the team through GitHub discussions or issues for additional feedback. +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml deleted file mode 100644 index 10591e6b23..0000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ /dev/null @@ -1,114 +0,0 @@ -name: "Bug Report" -description: Submit a bug report to help us improve TensorRT-LLM -labels: [ "bug" ] -body: - - type: textarea - id: system-info - attributes: - label: System Info - description: Please share your system info with us. - placeholder: | - - CPU architecture (e.g., x86_64, aarch64) - - CPU/Host memory size (if known) - - GPU properties - - GPU name (e.g., NVIDIA H100, NVIDIA A100, NVIDIA L40S) - - GPU memory size (if known) - - Clock frequencies used (if applicable) - - Libraries - - TensorRT-LLM branch or tag (e.g., main, v0.7.1) - - TensorRT-LLM commit (if known) - - Versions of TensorRT, Modelopt, CUDA, cuBLAS, etc. used - - Container used (if running TensorRT-LLM in a container) - - NVIDIA driver version - - OS (Ubuntu 24.04, CentOS 8) - - Any other information that may be useful in reproducing the bug - validations: - required: true - - - type: textarea - id: who-can-help - attributes: - label: Who can help? - description: | - To expedite the response to your issue, it would be helpful if you could identify the appropriate person - to tag using the **@** symbol. Here is a general guideline on **whom to tag**. - - Rest assured that all issues are reviewed by the core maintainers. If you are unsure about whom to tag, - you can leave it blank, and a core maintainer will make sure to involve the appropriate person. - - Please tag fewer than 3 people. - - Quantization: @Tracin - - Documentation: @juney-nvidia - - Feature request: @ncomly-nvidia - - Performance: @kaiyux - - placeholder: "@Username ..." - - - type: checkboxes - id: information-scripts-examples - attributes: - label: Information - description: 'The problem arises when using:' - options: - - label: "The official example scripts" - - label: "My own modified scripts" - - - type: checkboxes - id: information-tasks - attributes: - label: Tasks - description: "The tasks I am working on are:" - options: - - label: "An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)" - - label: "My own task or dataset (give details below)" - - - type: textarea - id: reproduction - validations: - required: true - attributes: - label: Reproduction - description: | - Kindly share a code example that demonstrates the issue you encountered. It is recommending to provide a code snippet directly. - Additionally, if you have any error messages, or stack traces related to the problem, please include them here. - - Remember to use code tags to properly format your code. You can refer to the - link https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting for guidance on code formatting. - - Please refrain from using screenshots, as they can be difficult to read and prevent others from copying and pasting your code. - It would be most helpful if we could reproduce your issue by simply copying and pasting your scripts and codes. - - placeholder: | - Steps to reproduce the behavior: - - 1. - 2. - 3. - - - type: textarea - id: expected-behavior - validations: - required: true - attributes: - label: Expected behavior - description: "Provide a brief summary of the expected behavior of the software. Provide output files or examples if possible." - - - type: textarea - id: actual-behavior - validations: - required: true - attributes: - label: actual behavior - description: "Describe the actual behavior of the software and how it deviates from the expected behavior. Provide output files or examples if possible." - - - type: textarea - id: additioanl-notes - validations: - required: true - attributes: - label: additional notes - description: "Provide any additional context here you think might be useful for the TensorRT-LLM team to help debug this issue (such as experiments done, potential things to investigate)." diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..93ef69beeb --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: 🤔 Questions + url: https://github.com/NVIDIA/TensorRT-LLM/discussions + about: Ask questions and discuss with other TensorRT-LLM community members diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 45f9ebf7f1..4665a9682a 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -18,6 +18,14 @@ Examples: - [https://nvbugs/1234567][fix] Fix some bugs - [#1234][doc] Update documentation - [None][chore] Minor clean-up + +Alternative (faster) way using CodeRabbit AI: + +**[JIRA ticket/NVBugs ID/GitHub issue/None] @coderabbitai title** + +NOTE: "@coderabbitai title" will be replaced by the title generated by CodeRabbit AI, that includes the "[type]" and title. +For more info, see /.coderabbit.yaml. + --> ## Description diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 00516b1afa..74f830f07d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: args: [--allow-multiple-documents] exclude: ".*/gitlab/.*.yml" - id: trailing-whitespace - exclude: '\.patch$' + exclude: '\.(patch|md)$' - id: check-toml - id: mixed-line-ending args: [--fix=lf] diff --git a/README.md b/README.md index bb58d309a5..f6625a0559 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ TensorRT-LLM [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/) [![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads) [![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt) -[![version](https://img.shields.io/badge/release-1.0.0rc6-green)](./tensorrt_llm/version.py) +[![version](https://img.shields.io/badge/release-1.1.0rc1-green)](./tensorrt_llm/version.py) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) [Architecture](./docs/source/torch/arch_overview.md)   |   [Performance](./docs/source/performance/perf-overview.md)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](./docs/source/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap) @@ -253,5 +253,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer ## Useful Links - [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT-LLM. - [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT-LLM. -- [AutoDeploy](./examples/auto_deploy/README.md): An experimental backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models. +- [AutoDeploy](./examples/auto_deploy/README.md): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models. - [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT-LLM Q&A and news. diff --git a/benchmarks/cpp/README.md b/benchmarks/cpp/README.md index 0b89bae602..ae3287faf0 100644 --- a/benchmarks/cpp/README.md +++ b/benchmarks/cpp/README.md @@ -336,7 +336,7 @@ cd cpp/build `disaggServerBenchmark` only supports `decoder-only` models. Here is the basic usage: ``` -export TRTLLM_USE_MPI_KVCACHE=1 +export TRTLLM_USE_UCX_KVCACHE=1 mpirun -n ${proc} benchmarks/disaggServerBenchmark --context_engine_dirs ${context_engine_0},${context_engine_1}...,${context_engine_{m-1}} \ --generation_engine_dirs ${generation_engine_0},${generation_engine_1}...,${generation_engine_{n-1}} --dataset ${dataset_path} ``` @@ -344,7 +344,7 @@ This command will launch m context engines and n generation engines. You need to for example: ``` -export TRTLLM_USE_MPI_KVCACHE=1 +export TRTLLM_USE_UCX_KVCACHE=1 mpirun -n 7 benchmarks/disaggServerBenchmark --context_engine_dirs ${llama_7b_tp2_pp1_dir},${llama_7b_tp1_pp1_dir} --generation_engine_dirs ${llama_7b_tp1_pp1_dir},${llama_7b_tp2_pp1_dir} --dataset ${dataset_path} # need 6 gpus and 7 processes to launch the benchmark. diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7c44c1ee0e..4f9d5f0f22 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -495,6 +495,17 @@ if(ENABLE_UCX) if(NOT ${ucx_FOUND}) set(ENABLE_UCX 0) else() + if(DEFINED ENV{GITHUB_MIRROR} AND NOT "$ENV{GITHUB_MIRROR}" STREQUAL "") + if(EXISTS "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake") + file(READ "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake" FILE_CONTENTS) + string( + REPLACE "https://raw.githubusercontent.com/rapidsai/rapids-cmake" + "$ENV{GITHUB_MIRROR}/rapidsai/rapids-cmake/raw/refs/heads" + FILE_CONTENTS "${FILE_CONTENTS}") + file(WRITE "${3RDPARTY_DIR}/ucxx/fetch_rapids.cmake" "${FILE_CONTENTS}") + message(WARNING "Replace UCXX fetch_rapids.cmake with internal mirror") + endif() + endif() # installing ucxx via add_subdirectory results in strange cudart linking # error, thus using their installation script to isolate the installation # process until the issue is understood. And always trigger the build so diff --git a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h index ce42493879..394f7fb7bf 100644 --- a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h +++ b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h @@ -75,27 +75,19 @@ public: std::vector<executor::LookaheadDecodingConfig>> operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, - runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, - runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream, - SizeType32 maxSequenceLength, SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const; + nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, + CudaStream const& runtimeStream, CudaStream const& decoderStream, SizeType32 maxSequenceLength, + SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const; [[nodiscard]] std::tuple<std::vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>> createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds, executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState, - runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, - runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, + nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const; private: - //! @brief Initialize the decoder at `batchSlot` with a new `request`. Exposed only for static batching via - //! GptDecoderBatched::newBatch() - static void newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, - runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream, - SizeType32 maxSequenceLength); - //! @brief Setups decoder internal tensors for new speculative decoding request static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h index a232230c4f..09a96a56ee 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h @@ -18,6 +18,7 @@ #include "tensorrt_llm/executor/executor.h" +#include <atomic> #include <chrono> #include <condition_variable> #include <deque> @@ -36,7 +37,8 @@ using BlockPtr = std::shared_ptr<KVCacheBlock>; class KVCacheEventManager { public: - explicit KVCacheEventManager(size_t maxKVEventEntries); + explicit KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank = std::nullopt, + std::optional<SizeType32> attentionDpSize = std::nullopt, SizeType32 attentionDpEventsGatherPeriodMs = 5); ~KVCacheEventManager(); KVCacheEventManager(KVCacheEventManager& other) = delete; @@ -61,14 +63,19 @@ public: // Worker thread which adds events to mEvents. void worker(); + // Thread which exchanges events if attentionDP is enabled + void exchangeAttentionDpThread(); + private: // Add an event to mEventQueue void enqueueEvent(executor::KVCacheEvent&& event); /// @brief Flag to terminate the worker - bool mRun; + std::atomic<bool> mRun; /// @brief Worker thread std::thread mWorkerThread; + /// @brief Exchange thread for attention DP events + std::thread mExchangeAttentionDpThread; /// @brief The deque of events std::deque<executor::KVCacheEvent> mEvents; @@ -91,6 +98,17 @@ private: size_t mMaxSize; /// @brief An auto-incrementing event id counter size_t mEventId; + + /// @brief Attention DP ranks and size + /// If set, we will exchange KV cache events and accumulate on rank 0 + std::optional<SizeType32> mAttentionDpRank; + std::optional<SizeType32> mAttentionDpSize; + + /// @brief The period in milliseconds to gather attention DP events across rank + SizeType32 mAttentionDpEventsGatherPeriodMs; + + /// @brief MPI communicator for attention DP + std::unique_ptr<tensorrt_llm::mpi::MpiComm> mMpiComm; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index a0234cbbe4..a49527a615 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -536,8 +536,7 @@ public: SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority, - std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse); + std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse); ~WindowBlockManager(); @@ -633,11 +632,6 @@ public: return mAllBlocksById.at(blockId); } - [[nodiscard]] BlockMapIterRange getBlocksByHash(size_t hash) const - { - return mContextBlocksByHash.equal_range(hash); - } - [[nodiscard]] SizeType32 getTokensPerBlock() const noexcept { return mTokensPerBlock; @@ -723,10 +717,6 @@ public: //! \param blockIds Id of each block. void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds); - void addBlockToHashMap(BlockPtr const& block); - - void removeBlockFromHashMap(BlockPtr const& block); - [[nodiscard]] bool verifyQueueIntegrity(); // Only needed when sliding window attention + paged context fmha are used together. @@ -808,8 +798,6 @@ private: SizeType32 mTokensPerBlock; // List of all blocks by idx std::vector<BlockPtr> mAllBlocksById; - // List of all context blocks by hash - BlockMap mContextBlocksByHash; // Dummy block acting as root for BlockToken searches BlockPtr mCachedBlocksRoot; // KV cache type (self or cross) @@ -841,8 +829,6 @@ private: double mReusedTokens; // Total number of input tokens double mTotalInputTokens; - // Whether or not to maintain a hashmap of blocks. - bool mEnableHashKey; // Whether blocks that are partially matched should be reused. bool mEnablePartialReuse; // Whether partially matched blocks that are already in use should be copied and reused. @@ -863,8 +849,8 @@ public: std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnPartialReuse = true); + std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnPartialReuse = true); BlockManager(BlockManager const&) = delete; BlockManager& operator=(BlockManager const&) = delete; @@ -1081,11 +1067,6 @@ public: return mWindowBlockManagers.at(windowSize).getBlockById(blockId); } - [[nodiscard]] WindowBlockManager::BlockMapIterRange getBlocksByHash(size_t hash, SizeType32 windowSize) const - { - return mWindowBlockManagers.at(windowSize).getBlocksByHash(hash); - } - [[nodiscard]] SizeType32 getNumPrimaryBlocks() const { return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); }); @@ -1096,16 +1077,6 @@ public: return getPool(poolIdx).containsBlockScales; } - void addBlockToHashMap(BlockPtr const& block, SizeType32 windowSize) - { - mWindowBlockManagers.at(windowSize).addBlockToHashMap(block); - } - - void removeBlockFromHashMap(BlockPtr const& block, SizeType32 windowSize) - { - mWindowBlockManagers.at(windowSize).removeBlockFromHashMap(block); - } - //! \brief Store context blocks void storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest); @@ -1385,8 +1356,8 @@ public: SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnpartialReuse = true); + std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnpartialReuse = true); KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1405,8 +1376,8 @@ public: SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnpartialReuse = true); + std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnpartialReuse = true); KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1692,8 +1663,6 @@ private: std::unordered_map<LlmRequest::RequestIdType, GenerationRequest> mSequences; // Whether to cache KV pages for reuse bool mEnableBlockReuse; - // Whether enable finding blocks by their hash, ignored when reuse enabled - bool mEnableHashKey; // Mutex to protect access to mSequences mutable std::mutex mSequencesMtx; // buffers for static tensors, will be created after allocating pools diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 0d087d96c0..e4d13c9e17 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -828,8 +828,10 @@ public: // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT : LlmRequestState::kCONTEXT_INIT; - mContextCurrentPosition = 0; - mPrepopulatedPromptLen = 0; + mContextCurrentPositionTarget = 0; + mContextCurrentPositionDraft = 0; + mPrepopulatedPromptLenTarget = 0; + mPrepopulatedPromptLenDraft = 0; mContextChunkSize = mPromptLen; mSeqSlot.reset(); } @@ -1049,7 +1051,7 @@ public: [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const { - return mPrepopulatedPromptLen; + return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget; } void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock) @@ -1066,7 +1068,10 @@ public: "Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen, promptLen, mRequestId); TLLM_CHECK(prepopulatedPromptLen < promptLen); - mPrepopulatedPromptLen = prepopulatedPromptLen; + + auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget; + auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget; + prePromptLen = prepopulatedPromptLen; if (prepopulatedPromptLen > 0) { @@ -1081,7 +1086,7 @@ public: chunkSize = flooredEndPosition - prepopulatedPromptLen; TLLM_CHECK(chunkSize <= getContextChunkSize()); } - setContextCurrentPosition(prepopulatedPromptLen); + contextCurrentPosition = prepopulatedPromptLen; setContextChunkSize(chunkSize); if (!isLastContextChunk()) @@ -1522,14 +1527,15 @@ public: void setContextCurrentPosition(SizeType32 contextCurrentPosition) { - mContextCurrentPosition = contextCurrentPosition; + mContextCurrentPositionDraft = contextCurrentPosition; + mContextCurrentPositionTarget = contextCurrentPosition; } /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning /// or end of the context is returned. [[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept { - return mContextCurrentPosition; + return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget; } /// Return the length of the context that has not yet been processed. @@ -1570,14 +1576,16 @@ public: { // The number of cached token is encountered in mContextCurrentPosition, // so the start position of the context is mPrepopulatedPromptLen. - return mContextCurrentPosition == mPrepopulatedPromptLen; + return getContextCurrentPosition() == getPrepopulatedPromptLen(); } /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context. void moveToNextContextChunk() { TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase."); - mContextCurrentPosition += getContextChunkSize(); + + mContextCurrentPositionDraft += getContextChunkSize(); + mContextCurrentPositionTarget += getContextChunkSize(); setContextChunkSize(0); } @@ -1843,6 +1851,16 @@ public: return mIsDummyRequest; } + void setUseDraftModel(bool useDraftModel) + { + mUseDraftModel = useDraftModel; + } + + [[nodiscard]] bool useDraftModel() const + { + return mUseDraftModel; + } + RequestIdType mRequestId; SizeType32 mPromptLen; SizeType32 mMaxNewTokens; @@ -1885,7 +1903,8 @@ protected: // Number of tokens already in KV cache before context phase. // A value > 0 indicates cached KV cache blocks were reused. // Up to inputLen - 1 tokens can be reused. - SizeType32 mPrepopulatedPromptLen{0}; + SizeType32 mPrepopulatedPromptLenTarget{0}; + SizeType32 mPrepopulatedPromptLenDraft{0}; SizeType32 mMaxSentTokenLen; @@ -1916,7 +1935,8 @@ protected: // The size of the context chunk must be multiple of the KV-Cache block size except the last one. // Value `0` means Chunked-Context is disabled. SizeType32 mContextChunkSize{0}; - SizeType32 mContextCurrentPosition{0}; + SizeType32 mContextCurrentPositionTarget{0}; + SizeType32 mContextCurrentPositionDraft{0}; std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen] VecLogProbs mCumLogProbs; // [beamSize] @@ -2017,6 +2037,8 @@ protected: bool mIsDummyRequest{false}; + bool mUseDraftModel{false}; + private: void initialize(VecTokens const& inputTokens, bool outputLogProbs) { @@ -2027,7 +2049,7 @@ private: // Scatter the input tokens to other beam mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens); - mLastTokens = VecTokens(mSamplingConfig.beamWidth); + mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back()); // Init mUniqueTokens VecUniqueTokens uniqueTokens{inputTokens.size()}; @@ -2347,6 +2369,9 @@ public: void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager); void moveLoraWeightsToGpu(runtime::BufferManager const& manager); + + // Remove LoRA weights and LoRA config tensors + void removeLoraTensors(); }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/common/assert.h b/cpp/include/tensorrt_llm/common/assert.h index e7e24bf549..0e916b7746 100644 --- a/cpp/include/tensorrt_llm/common/assert.h +++ b/cpp/include/tensorrt_llm/common/assert.h @@ -16,25 +16,8 @@ #pragma once -#include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/common/tllmException.h" -#include <string> - -namespace tensorrt_llm::common -{ -[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, char const* info) -{ - throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info).c_str()); -} - -[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "") -{ - throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()).c_str()); -} - -} // namespace tensorrt_llm::common - class DebugConfig { public: @@ -86,12 +69,3 @@ public: __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__).c_str()); \ } \ } while (0) - -#define TLLM_THROW(...) \ - do \ - { \ - throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \ - } while (0) - -#define TLLM_WRAP(ex) \ - NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what()) diff --git a/cpp/include/tensorrt_llm/common/quantization.h b/cpp/include/tensorrt_llm/common/quantization.h index 836faa258f..50aae114e0 100644 --- a/cpp/include/tensorrt_llm/common/quantization.h +++ b/cpp/include/tensorrt_llm/common/quantization.h @@ -122,6 +122,16 @@ public: return QuantMode(BaseType(1u) << 14); } + static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept + { + return QuantMode(BaseType(1u) << 15); + } + + static constexpr QuantMode w4a16Mxfp4() noexcept + { + return QuantMode(BaseType(1u) << 16); + } + constexpr BaseType value() const noexcept { return mValue; @@ -202,6 +212,16 @@ public: return isSet(w4a8Mxfp4Fp8()); } + constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept + { + return isSet(w4a8Mxfp4Mxfp8()); + } + + constexpr bool hasW4a16Mxfp4() const noexcept + { + return isSet(w4a16Mxfp4()); + } + constexpr bool hasKvCacheQuant() const noexcept { return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache(); @@ -209,7 +229,8 @@ public: static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken, bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq, - bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8) + bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8, + bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4) { QuantMode quantMode{}; if (quantizeWeights) @@ -278,25 +299,35 @@ public: quantMode += w4a8Mxfp4Fp8(); } + if (useW4a8Mxfp4Mxfp8) + { + quantMode += w4a8Mxfp4Mxfp8(); + } + + if (useW4a16Mxfp4) + { + quantMode += w4a16Mxfp4(); + } + return quantMode; } static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false) { - return fromDescription( - true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false); + return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false, + false, false, false, false); } static constexpr QuantMode useQServe(bool perGroup) { - return fromDescription( - true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false); + return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false, + false, false, false); } static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false) { return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false, - false, false, false); + false, false, false, false, false); } static QuantMode const fromQuantAlgo( @@ -353,28 +384,38 @@ public: } else if (quantAlgo == "FP8") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, true, false, false, false, false, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false, + false, false, false, false, false); } else if (quantAlgo == "FP8_ROWWISE") { - quantMode = fromDescription( - false, false, true, true, false, false, false, false, false, true, false, false, false, false); + quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false, + false, false, false, false); } else if (quantAlgo == "FP4") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, true, false, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + true, false, false, false, false); } else if (quantAlgo == "FP8_BLOCK_SCALES") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, false, true, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, true, false, false, false); } else if (quantAlgo == "W4A8_MXFP4_FP8") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, false, false, true); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, true, false, false); + } + else if (quantAlgo == "W4A8_MXFP4_MXFP8") + { + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, false, true, false); + } + else if (quantAlgo == "W4A16_MXFP4") + { + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, true); } if (kvCacheQuantAlgo == "INT8") diff --git a/cpp/include/tensorrt_llm/common/stringUtils.h b/cpp/include/tensorrt_llm/common/stringUtils.h index e806c71452..a4803cba37 100644 --- a/cpp/include/tensorrt_llm/common/stringUtils.h +++ b/cpp/include/tensorrt_llm/common/stringUtils.h @@ -74,6 +74,28 @@ void printElement(std::ostream& os, std::tuple<Args...> const& t) printTupleImpl(os, t, std::index_sequence_for<Args...>{}); } +class va_list_guard +{ +public: + explicit va_list_guard(va_list& args) + : mArgs(args) + { + } + + ~va_list_guard() + { + va_end(mArgs); + } + + va_list_guard(va_list_guard const&) = delete; + va_list_guard& operator=(va_list_guard const&) = delete; + va_list_guard(va_list_guard&&) = delete; + va_list_guard& operator=(va_list_guard&&) = delete; + +private: + va_list& mArgs; +}; + } // namespace // Override operator<< for any tuple @@ -117,6 +139,8 @@ inline std::string fmtstr(char const* format, ...) va_list args; va_start(args, format); + va_list_guard args_guard(args); + fmtstr_( format, [](void* target, size_t count) -> char* @@ -131,7 +155,6 @@ inline std::string fmtstr(char const* format, ...) return str->data(); }, &result, args); - va_end(args); return result; } diff --git a/cpp/include/tensorrt_llm/common/tllmException.h b/cpp/include/tensorrt_llm/common/tllmException.h index 15a1a77019..b24e6230fd 100644 --- a/cpp/include/tensorrt_llm/common/tllmException.h +++ b/cpp/include/tensorrt_llm/common/tllmException.h @@ -16,11 +16,22 @@ #pragma once +#include "tensorrt_llm/common/stringUtils.h" + #include <array> #include <cstddef> #include <stdexcept> #include <string> +#define TLLM_THROW(...) \ + do \ + { \ + throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \ + } while (0) + +#define TLLM_WRAP(ex) \ + NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what()) + #define NEW_TLLM_EXCEPTION(...) \ tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__).c_str()) @@ -45,4 +56,14 @@ private: int mNbFrames; }; +[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, char const* info) +{ + throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info).c_str()); +} + +[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "") +{ + throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()).c_str()); +} + } // namespace tensorrt_llm::common diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 6d592654ff..0a58298c27 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -1001,6 +1001,7 @@ public: std::optional<FloatType> const& crossKvCacheFraction = std::nullopt, std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0, bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false, + SizeType32 attentionDpEventsGatherPeriodMs = 5, std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt); [[nodiscard]] bool getEnableBlockReuse() const; @@ -1016,6 +1017,7 @@ public: [[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const; [[nodiscard]] size_t getEventBufferMaxSize() const; [[nodiscard]] bool getUseUvm() const; + [[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const; void setEnableBlockReuse(bool enableBlockReuse); void setEnablePartialReuse(bool enablePartialReuse); @@ -1030,6 +1032,7 @@ public: void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority); void setEventBufferMaxSize(size_t eventBufferMaxSize); void setUseUvm(bool useUvm); + void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs); void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults); @@ -1085,6 +1088,9 @@ private: /// @brief Whether to use UVM for the KV cache. bool mUseUvm; + + /// @brief The period in milliseconds to gather attention DP events across ranks + SizeType32 mAttentionDpEventsGatherPeriodMs; }; /// @brief Configuration class for the runtime perf knobs @@ -1702,6 +1708,12 @@ struct KVCacheUpdatedData explicit KVCacheUpdatedData(IdType blockHash) : blockHash{blockHash} {}; + explicit KVCacheUpdatedData(IdType blockHash, std::optional<KVCacheEventDiff<SizeType32>> cacheLevel, + std::optional<KVCacheEventDiff<SizeType32>> priority) + : blockHash{blockHash} + , cacheLevel{cacheLevel} + , priority{priority} {}; + KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue) { cacheLevel = KVCacheEventDiff<SizeType32>{oldValue, newValue}; @@ -1726,8 +1738,8 @@ using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVC struct KVCacheEvent { - - KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize); + KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize, + std::optional<SizeType32> attentionDpRank = std::nullopt); /// @brief The unique id of this event IdType eventId; @@ -1735,6 +1747,8 @@ struct KVCacheEvent KVCacheEventData data; /// @brief The sliding window size SizeType32 windowSize; + /// @brief The attention DP rank of the event, if applicable + std::optional<SizeType32> attentionDpRank; }; /// @brief Exposes a limited set of KV cache manager functionalities diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index b2ecfc66c8..c370a65235 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -302,6 +302,53 @@ public: [[nodiscard]] static std::vector<RequestStatsPerIteration> deserializeRequestStatsPerIterationVec( std::vector<char>& buffer); + // KVCacheEvent deque + [[nodiscard]] static std::vector<char> serialize(std::deque<KVCacheEvent> const& kvCacheEvents); + [[nodiscard]] static std::deque<KVCacheEvent> deserializeKVCacheEvents(std::vector<char>& buffer); + + // KVCacheEvent + [[nodiscard]] static size_t serializedSize(KVCacheEvent const& event); + static void serialize(KVCacheEvent const& event, std::ostream& os); + [[nodiscard]] static KVCacheEvent deserializeKVCacheEvent(std::istream& is); + + // KVCacheCreatedData + [[nodiscard]] static size_t serializedSize(KVCacheCreatedData const& data); + static void serialize(KVCacheCreatedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheCreatedData deserializeKVCacheCreatedData(std::istream& is); + + // KVCacheStoredData + [[nodiscard]] static size_t serializedSize(KVCacheStoredData const& data); + static void serialize(KVCacheStoredData const& data, std::ostream& os); + [[nodiscard]] static KVCacheStoredData deserializeKVCacheStoredData(std::istream& is); + + // KVCacheStoredBlockData + [[nodiscard]] static size_t serializedSize(KVCacheStoredBlockData const& data); + static void serialize(KVCacheStoredBlockData const& data, std::ostream& os); + [[nodiscard]] static KVCacheStoredBlockData deserializeKVCacheStoredBlockData(std::istream& is); + + // KVCacheRemovedData + [[nodiscard]] static size_t serializedSize(KVCacheRemovedData const& data); + static void serialize(KVCacheRemovedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is); + + // KVCacheEventDiff + template <typename T> + [[nodiscard]] static size_t serializedSize(KVCacheEventDiff<T> const& data); + template <typename T> + static void serialize(KVCacheEventDiff<T> const& data, std::ostream& os); + template <typename T> + [[nodiscard]] static KVCacheEventDiff<T> deserializeKVCacheEventDiff(std::istream& is); + + // KVCacheUpdateData + [[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data); + static void serialize(KVCacheUpdatedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is); + + // UniqueToken + [[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token); + static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os); + [[nodiscard]] static tensorrt_llm::runtime::UniqueToken deserializeUniqueToken(std::istream& is); + // String static std::string deserializeString(std::istream& is); diff --git a/cpp/include/tensorrt_llm/runtime/decoderState.h b/cpp/include/tensorrt_llm/runtime/decoderState.h index e4fe9c3801..95d7ff0ffa 100644 --- a/cpp/include/tensorrt_llm/runtime/decoderState.h +++ b/cpp/include/tensorrt_llm/runtime/decoderState.h @@ -51,13 +51,13 @@ public: DecoderState(); //! @brief Setup buffers for the decoder excluding speculative decoding. - void setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, + void setup(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager); //! @brief Setup buffers for the cache indirection. //! @details This is used for beam search on pipeline parallel ranks without a decoder. - void setupCacheIndirection(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, + void setupCacheIndirection(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, BufferManager const& bufferManager); //! @brief Setup buffers for speculative decoding. @@ -134,7 +134,7 @@ public: //! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu [[nodiscard]] TensorPtr getAcceptedPackedPaths() const; - [[nodiscard]] SizeType32 getMaxBatchSize() const; + [[nodiscard]] SizeType32 getMaxNumSequences() const; [[nodiscard]] SizeType32 getMaxBeamWidth() const; @@ -173,6 +173,11 @@ public: //! @brief Workspace for beam search in streaming mode. [[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const; + //! @brief Set the beam width for a specific request in the batch. + //! @param batchIdx The index of the request in the batch. + //! @param beamWidth The beam width for the specified request. + void setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth); + //! @brief Cache indirection input for beam search. [[nodiscard]] TensorPtr getCacheIndirectionInput() const; @@ -187,10 +192,10 @@ public: //! @param generationSteps The generation steps for all requests in the batch. void setGenerationSteps(std::vector<SizeType32> const& generationSteps); - //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots. [[nodiscard]] DecodingInput& getJointDecodingInput() const; - //! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots. [[nodiscard]] DecodingOutput& getJointDecodingOutput() const; private: @@ -209,13 +214,13 @@ private: SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager); - SizeType32 mMaxBatchSize{}; + SizeType32 mMaxNumSequences{}; SizeType32 mMaxBeamWidth{}; SizeType32 mMaxSequenceLength{}; - //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots. DecodingInputPtr mJointDecodingInput; - //! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots. DecodingOutputPtr mJointDecodingOutput; //! @brief Workspace for beam search in streaming mode. diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoder.h b/cpp/include/tensorrt_llm/runtime/gptDecoder.h index 90690c90fc..7e0cc1bb56 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoder.h @@ -71,7 +71,7 @@ public: = 0; static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype, - size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, + size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream, std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule = nullptr); }; @@ -84,7 +84,7 @@ public: using CudaStreamPtr = BufferManager::CudaStreamPtr; using TensorPtr = std::shared_ptr<ITensor>; - GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, + GptDecoder(executor::DecodingMode const& mode, size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, CudaStreamPtr const& stream, std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr); @@ -114,7 +114,7 @@ private: SamplingConfig mSamplingConfig; - size_t mMaxBatchSize; + size_t mMaxNumSequences; size_t mVocabSize; size_t mVocabSizePadded; @@ -122,7 +122,7 @@ private: }; inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype, - size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, + size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream, std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule) { @@ -130,10 +130,10 @@ inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode c { case nvinfer1::DataType::kFLOAT: return std::make_unique<GptDecoder<float>>( - mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); + mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); case nvinfer1::DataType::kHALF: return std::make_unique<GptDecoder<half>>( - mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); + mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); default: TLLM_THROW("Unsupported decoder data type: %d. Use either kFLOAT or kHALF.", static_cast<int>(dtype)); return nullptr; diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index d5dfe9b7b1..d0a9e726d1 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -47,7 +47,7 @@ public: explicit GptDecoderBatched(CudaStreamPtr stream); - void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth, + void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override; void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h index 327af71f8a..606ba3c98a 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h @@ -86,7 +86,7 @@ public: using TensorPtr = std::shared_ptr<ITensor>; //! @brief Setup the decoder before calling `forward()` - virtual void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth, + virtual void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) = 0; diff --git a/cpp/include/tensorrt_llm/runtime/request.h b/cpp/include/tensorrt_llm/runtime/request.h index 1861ea8431..e8f851b7d7 100644 --- a/cpp/include/tensorrt_llm/runtime/request.h +++ b/cpp/include/tensorrt_llm/runtime/request.h @@ -31,26 +31,16 @@ public: using TensorPtr = ITensor::SharedPtr; using BufferPtr = IBuffer::SharedPtr; - explicit Request(TensorConstPtr ids, SizeType32 inputLen, std::optional<SizeType32> maxNewTokens = std::nullopt, - std::optional<SizeType32> endId = std::nullopt) - : ids{std::move(ids)} - , inputLen(inputLen) - , maxNewTokens{maxNewTokens} - , endId{endId} + explicit Request(SizeType32 inputLen) + : inputLen(inputLen) { } //! Mandatory parameters - TensorConstPtr ids; // The input sequence of token ids, [inputSeqLen], on gpu SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps // optional parameters - std::optional<SizeType32> maxNewTokens; // maximum number of tokens to generate for this request - std::optional<SizeType32> endId; // end token id SizeType32 generatedTokensPerEngineStep{1}; // - TensorPtr embeddingBias; // [vocabSizePadded], on gpu - TensorPtr badWordsList; // [2, badWordsLength] on gpu - TensorPtr stopWordsList; // [2, stopWordsLength] on gpu //! Optional parameters for speculative decoding BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu diff --git a/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h b/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h index 4443d422ab..32c086c84e 100644 --- a/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h +++ b/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h @@ -68,6 +68,10 @@ enum class MpiTag : int // LogitsThread kSpecDecLogitsId = 129, kSpecDecLogitsData = 1025, + + // KvCacheEventManager + kKvCacheEventSize = 1026, + kKvCacheEvent = 1027 }; } // namespace tensorrt_llm::mpi diff --git a/cpp/kernels/fmha_v2/fmha_test.py b/cpp/kernels/fmha_v2/fmha_test.py index 1661db3e25..21486b00ea 100644 --- a/cpp/kernels/fmha_v2/fmha_test.py +++ b/cpp/kernels/fmha_v2/fmha_test.py @@ -55,7 +55,7 @@ def getSMVersion(): ids=["fp16", "bf16", "fp16-fp32", "e4m3"]) @pytest.mark.parametrize('flag', [ "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv", - "-softcapping-scale-bmm1 30", "-contiguous-q-kv" + "-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks" ]) @pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"]) def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel): @@ -122,8 +122,8 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel): f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}", shell=True, check=True) - # alibi and softcapping-scale-bmm1 are mutually exclusive. - if '-softcapping-scale-bmm1' not in flag: + # alibi doesn't work with softcapping-scale-bmm1/use-attention-sinks. + if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag: subprocess.run( f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}", shell=True, @@ -183,6 +183,23 @@ def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout): shell=True, check=True) + # For chunked prefill, we need to enable -save-softmax (dtype: bf16, sm90, layout: paged-kv or separate-q-k-v). + if dtype == "-bf16" and input_layout in [ + "-paged-kv", "-separate-q-k-v" + ]: + # padding mask + subprocess.run( + f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \ + {epsilon} {input_layout} -save-softmax", + shell=True, + check=True) + # causal mask + subprocess.run( + f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \ + -causal-mask {epsilon} {input_layout} -save-softmax", + shell=True, + check=True) + @pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"], ids=["bf16", "e4m3", "e4m3-bf16"]) diff --git a/cpp/kernels/fmha_v2/setup.py b/cpp/kernels/fmha_v2/setup.py index e7a3986455..8434d4225d 100644 --- a/cpp/kernels/fmha_v2/setup.py +++ b/cpp/kernels/fmha_v2/setup.py @@ -1971,8 +1971,7 @@ def selected_mask_types(kspec): sliding_or_chunked_causal_mask = '0' custom_mask = '0' elif (kspec.head_size, kspec.head_size_v) == (192, 128): - # MLA context phase only needs causal mask now - padding_mask = '0' + # MLA context phase only needs causal mask and padding mask (for chunked prefill) now sliding_or_chunked_causal_mask = '0' custom_mask = '0' elif (kspec.head_size, kspec.head_size_v) == (576, 512): @@ -2311,8 +2310,7 @@ def get_api_code(specs_names): # whether support alibi or not. if kspec.warp_specialization: il_check += '&& params.has_alibi ' if kspec.alibi else '&& !params.has_alibi ' - if kspec.input_layout.value == InputLayout.CONTIGUOUS_Q_KV: - il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr ' + il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr ' # use enable_attn_logit_softcapping or not. il_check += '&& enable_attn_logit_softcapping ' if kspec.enable_attn_logit_softcapping else '&& !enable_attn_logit_softcapping ' # check sage block sizes @@ -3653,104 +3651,110 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'): # alibi and enable_attn_logit_softcapping shouldn't be used together. if alibi and enable_attn_logit_softcapping: continue - if input_layout != InputLayout.CONTIGUOUS_Q_KV and return_softmax: - continue - # only specify - specs.append( - kernel_spec( - sm=sm, - sm_mma=90, - dtype=dtype, - seq_len=0, # support any sequence length - head_size=[32, 40, 48, 64], - warps_m=4, #4x1 warpgroups - warps_n=1, - version=2, - interleaved=False, - ldgsts_q= - False, # for Hopper kernels, ldgsts = False signals TMA usage. - ldgsts_k=False, - ldgsts_v=False, - share_smem_k_v=False, - loop_step=64, - q_tile_buffers=1, # only used by warp specialized kernels - has_noloop=0, - noloop_step=64, - kv_loop_step=256, - kv_tile_buffers=2, # only used by warp specialized kernels - unroll_threshold=1, - has_scale_max=False, - flash_attention=True, - warp_specialization=True, - alibi=alibi, - enable_attn_logit_softcapping=enable_attn_logit_softcapping, - return_softmax_stats=return_softmax, - scheduling_mode=scheduling_mode, - input_layout=input_layout)) + # for normal attention, we only need contiguous kv as input layout when returning softmax. + skip_combination = return_softmax and (input_layout + != InputLayout.CONTIGUOUS_Q_KV) + # for context mla, we need paged kv or separate qkv as input layout when returning softmax. + skip_mla_combination = return_softmax and ( + input_layout != InputLayout.Q_PAGED_KV + and input_layout != InputLayout.SEPARATE_Q_K_V) + if not skip_combination: + # only specify + specs.append( + kernel_spec( + sm=sm, + sm_mma=90, + dtype=dtype, + seq_len=0, # support any sequence length + head_size=[32, 40, 48, 64], + warps_m=4, #4x1 warpgroups + warps_n=1, + version=2, + interleaved=False, + ldgsts_q= + False, # for Hopper kernels, ldgsts = False signals TMA usage. + ldgsts_k=False, + ldgsts_v=False, + share_smem_k_v=False, + loop_step=64, + q_tile_buffers=1, # only used by warp specialized kernels + has_noloop=0, + noloop_step=64, + kv_loop_step=256, + kv_tile_buffers=2, # only used by warp specialized kernels + unroll_threshold=1, + has_scale_max=False, + flash_attention=True, + warp_specialization=True, + alibi=alibi, + enable_attn_logit_softcapping=enable_attn_logit_softcapping, + return_softmax_stats=return_softmax, + scheduling_mode=scheduling_mode, + input_layout=input_layout)) - specs.append( - kernel_spec( - sm=sm, - sm_mma=90, - dtype=dtype, - seq_len=0, # support any sequence length - head_size=[72, 80, 96, 104, 128], - warps_m=4, #4x1 warpgroups - warps_n=1, - version=2, - interleaved=False, - ldgsts_q= - False, # for Hopper kernels, ldgsts = False signals TMA usage. - ldgsts_k=False, - ldgsts_v=False, - share_smem_k_v=False, - loop_step=64, - q_tile_buffers=1, # only used by warp specialized kernels - has_noloop=0, - noloop_step=64, - kv_loop_step=128, - kv_tile_buffers=2, # only used by warp specialized kernels - unroll_threshold=1, - has_scale_max=False, - flash_attention=True, - warp_specialization=True, - alibi=alibi, - enable_attn_logit_softcapping=enable_attn_logit_softcapping, - return_softmax_stats=return_softmax, - scheduling_mode=scheduling_mode, - input_layout=input_layout)) + specs.append( + kernel_spec( + sm=sm, + sm_mma=90, + dtype=dtype, + seq_len=0, # support any sequence length + head_size=[72, 80, 96, 104, 128], + warps_m=4, #4x1 warpgroups + warps_n=1, + version=2, + interleaved=False, + ldgsts_q= + False, # for Hopper kernels, ldgsts = False signals TMA usage. + ldgsts_k=False, + ldgsts_v=False, + share_smem_k_v=False, + loop_step=64, + q_tile_buffers=1, # only used by warp specialized kernels + has_noloop=0, + noloop_step=64, + kv_loop_step=128, + kv_tile_buffers=2, # only used by warp specialized kernels + unroll_threshold=1, + has_scale_max=False, + flash_attention=True, + warp_specialization=True, + alibi=alibi, + enable_attn_logit_softcapping=enable_attn_logit_softcapping, + return_softmax_stats=return_softmax, + scheduling_mode=scheduling_mode, + input_layout=input_layout)) - specs.append( - kernel_spec( - sm=sm, - sm_mma=90, - dtype=dtype, - seq_len=0, # support any sequence length - head_size=[160, 192, 256], - warps_m=4, #4x1 warpgroups - warps_n=1, - version=2, - interleaved=False, - ldgsts_q= - False, # for Hopper kernels, ldgsts = False signals TMA usage. - ldgsts_k=False, - ldgsts_v=False, - share_smem_k_v=False, - loop_step=64, - q_tile_buffers=1, # only used by warp specialized kernels - has_noloop=0, - noloop_step=64, - kv_loop_step=64, - kv_tile_buffers=2, # only used by warp specialized kernels - unroll_threshold=1, - has_scale_max=False, - flash_attention=True, - warp_specialization=True, - alibi=alibi, - enable_attn_logit_softcapping=enable_attn_logit_softcapping, - return_softmax_stats=return_softmax, - scheduling_mode=scheduling_mode, - input_layout=input_layout)) + specs.append( + kernel_spec( + sm=sm, + sm_mma=90, + dtype=dtype, + seq_len=0, # support any sequence length + head_size=[160, 192, 256], + warps_m=4, #4x1 warpgroups + warps_n=1, + version=2, + interleaved=False, + ldgsts_q= + False, # for Hopper kernels, ldgsts = False signals TMA usage. + ldgsts_k=False, + ldgsts_v=False, + share_smem_k_v=False, + loop_step=64, + q_tile_buffers=1, # only used by warp specialized kernels + has_noloop=0, + noloop_step=64, + kv_loop_step=64, + kv_tile_buffers=2, # only used by warp specialized kernels + unroll_threshold=1, + has_scale_max=False, + flash_attention=True, + warp_specialization=True, + alibi=alibi, + enable_attn_logit_softcapping=enable_attn_logit_softcapping, + return_softmax_stats=return_softmax, + scheduling_mode=scheduling_mode, + input_layout=input_layout)) ''' smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size @@ -3762,38 +3766,39 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'): Then for fp16/bf16 context MLA, d remains 192 (192 * 2 = 128 * 3), and dv remains 128, if kv_step = 128, then smem_size = 208 KB, smem is fully utilized. ''' - specs.append( - kernel_spec( - sm=sm, - sm_mma=90, - dtype=dtype, - seq_len=0, # support any sequence length - head_size=192, - head_size_v=128, - warps_m=4, #4x1 warpgroups - warps_n=1, - version=2, - interleaved=False, - ldgsts_q= - False, # for Hopper kernels, ldgsts = False signals TMA usage. - ldgsts_k=False, - ldgsts_v=False, - share_smem_k_v=False, - loop_step=64, - q_tile_buffers=1, # only used by warp specialized kernels - has_noloop=0, - noloop_step=64, - kv_loop_step=128, - kv_tile_buffers=2, # only used by warp specialized kernels - unroll_threshold=1, - has_scale_max=False, - flash_attention=True, - warp_specialization=True, - alibi=alibi, - enable_attn_logit_softcapping=enable_attn_logit_softcapping, - return_softmax_stats=return_softmax, - scheduling_mode=scheduling_mode, - input_layout=input_layout)) + if not skip_mla_combination: + specs.append( + kernel_spec( + sm=sm, + sm_mma=90, + dtype=dtype, + seq_len=0, # support any sequence length + head_size=192, + head_size_v=128, + warps_m=4, #4x1 warpgroups + warps_n=1, + version=2, + interleaved=False, + ldgsts_q= + False, # for Hopper kernels, ldgsts = False signals TMA usage. + ldgsts_k=False, + ldgsts_v=False, + share_smem_k_v=False, + loop_step=64, + q_tile_buffers=1, # only used by warp specialized kernels + has_noloop=0, + noloop_step=64, + kv_loop_step=128, + kv_tile_buffers=2, # only used by warp specialized kernels + unroll_threshold=1, + has_scale_max=False, + flash_attention=True, + warp_specialization=True, + alibi=alibi, + enable_attn_logit_softcapping=enable_attn_logit_softcapping, + return_softmax_stats=return_softmax, + scheduling_mode=scheduling_mode, + input_layout=input_layout)) # Note this will be used in TRT-LLM. diff --git a/cpp/kernels/fmha_v2/src/fmha/fragment.h b/cpp/kernels/fmha_v2/src/fmha/fragment.h index 4f1202d41e..8be16df163 100644 --- a/cpp/kernels/fmha_v2/src/fmha/fragment.h +++ b/cpp/kernels/fmha_v2/src/fmha/fragment.h @@ -1904,8 +1904,7 @@ struct Softmax_saver , softmax_sum_ptr_(reinterpret_cast<char*>(params.softmax_stats_ptr)) , softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes) { - size_t softmax_max_off = sizeof(float) * params.b * params.s * params.h; - softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr) + softmax_max_off; + softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr); int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; @@ -1917,9 +1916,9 @@ struct Softmax_saver store_softmax_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0); // assume fixed seq length for the batch - size_t const bh_offset = (binfo.sum_s * params.h + binfo.bidh) * sizeof(float); - softmax_sum_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes; + size_t const bh_offset = (binfo.sum_s * params.h + binfo.bidh) * sizeof(float) * 2; softmax_max_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes; + softmax_sum_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes + sizeof(float); }; inline __device__ void store(int q_loop, float* p_sum, float* p_max) @@ -1938,19 +1937,19 @@ struct Softmax_saver int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA; if (row0_ + row_offset < actual_q_len_) { - fmha::stg(softmax_sum_ptr_ + row_offset * softmax_stats_stride_in_bytes_, sum0); fmha::stg(softmax_max_ptr_ + row_offset * softmax_stats_stride_in_bytes_, max0); + fmha::stg(softmax_sum_ptr_ + row_offset * softmax_stats_stride_in_bytes_, sum0); } if (row0_ + row_offset + 8 < actual_q_len_) { - fmha::stg(softmax_sum_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, sum1); fmha::stg(softmax_max_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, max1); + fmha::stg(softmax_sum_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, sum1); } } } } - // ptr + // ptr (total_token_q, h, 2) float char* softmax_sum_ptr_ = nullptr; char* softmax_max_ptr_ = nullptr; diff --git a/cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h b/cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h index c89a9453ee..a6433856f5 100644 --- a/cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h +++ b/cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h @@ -465,8 +465,7 @@ struct Softmax_saver_tma , softmax_sum_ptr_(reinterpret_cast<char*>(params.softmax_stats_ptr)) , softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes) { - size_t softmax_max_off = sizeof(float) * params.b * params.s * params.h; - softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr) + softmax_max_off; + softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr); int warp = (threadIdx.x % 128) / Cta_tile::THREADS_PER_WARP; int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; // MMA row0 index (8x4 thread layout) @@ -474,9 +473,9 @@ struct Softmax_saver_tma int sum_s = params.is_s_padded ? params.s * head_info.bidb : params.cu_q_seqlens[head_info.bidb]; int token_id = sum_s * params.h + head_info.bidh; - size_t const bh_offset = token_id * sizeof(float) + local_q_tile_offset_ * softmax_stats_stride_in_bytes_; - softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_; + size_t const bh_offset = token_id * sizeof(float) * 2 + local_q_tile_offset_ * softmax_stats_stride_in_bytes_; softmax_max_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_; + softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_ + sizeof(float); }; inline __device__ void store(float* p_sum, float* p_max, float sqrt_d, int row_offset, bool valid_run) @@ -487,7 +486,7 @@ struct Softmax_saver_tma int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; if (lane % 4 < 2) { - values = p_sum[lane % 2] == 0.f ? 1.f : 1.0f / p_sum[lane % 2]; + values = p_sum[lane % 2]; } else { diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h index 65e56dbf5d..eed6f852da 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h @@ -326,9 +326,6 @@ struct Compute uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]); Compute_tile_o ctile_o(0, smem_v); - // BMM2 epilogue - Tile_o_epilogue tile_o_epilogue(params); - // Mutex between two compute groups. OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER); // Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions). @@ -368,6 +365,9 @@ struct Compute sage_scale_row = head_info.bidb * params.h + head_info.bidh; } + // BMM2 epilogue + Tile_o_epilogue tile_o_epilogue(params, head_info); + int q_step_idx = warpgroup_id; // Compute work. @@ -490,7 +490,7 @@ struct Compute if (valid_run) { // Final step's update. - tile_o_epilogue.scale(ctile_o, p_sum); + tile_o_epilogue.scale(ctile_o, p_max, p_sum); // Store o_tile to gmem. gmem_o.store(ctile_o.acc_); } diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h index 12e73bedf1..c8e3c318d6 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h @@ -500,7 +500,7 @@ struct DMA int const num_valid_kv_blocks = (actual_kv_seqlen + params.paged_kv_cache.mTokensPerBlock - 1) >> params.paged_kv_cache.mTokensPerBlockLog2; - for (int q_step_idx = 0; q_step_idx < q_steps; q_step_idx += 2) + for (int q_step_idx = 0; q_step_idx < q_steps && actual_kv_seqlen > 0; q_step_idx += 2) { load_q(bidh, q_step_idx * STEP_Q + local_q_tile_offset, desc_q, shared->smem_q[0], cbw0); load_q(bidh, (q_step_idx + 1) * STEP_Q + local_q_tile_offset, desc_q, shared->smem_q[1], cbw1); diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h index 217e8c0872..99ea1643cd 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h @@ -454,7 +454,7 @@ struct Softmax_base #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - uint32_t const scale = float_to_half2(correction_[mi]); + const uint32_t scale = float_to_half2(correction_[mi]); // Assume only N has multiple MMAs (MMAS_M = 1). // MMAS_N > 1 when N dimension is split. @@ -477,9 +477,15 @@ struct Softmax_base } // BMM1 scale. - uint32_t const scale_bmm1_; + const uint32_t scale_bmm1_; // BMM1 softcapping scale. float const softcapping_scale_bmm1_; + + // The sliding window size. + int const sliding_window_size_; + // The log2 attention chunk size. + int const log2_chunked_attention_size_; + // The thread idx in the warp group. int tidx_; // The col index for the mma thread layout. @@ -487,15 +493,10 @@ struct Softmax_base // The row index for the mma thread layout. int quad_row_; - // The sliding window size. - int const sliding_window_size_; - // The log2 attention chunk size. - int const log2_chunked_attention_size_; - // The packed mask ptr. uint32_t const* packed_mask_ptr_; // The packed mask k-dim stride in bytes; - int64_t const params_packed_mask_stride_in_bytes_; + const int64_t params_packed_mask_stride_in_bytes_; // Unpacked BMM1 output buffer. float elt_[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2]; @@ -1072,20 +1073,53 @@ struct Tile_o_epilogue_base // The MMA tile for the BMM2. using Mma_tile_o = typename Kernel_traits::Mma_tile_o; - template <typename Params> - inline __device__ Tile_o_epilogue_base(Params const& params) + // Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs). + enum { - ; // nothing to construct. + EXP2F_OPTIMIZATION = Kernel_traits::EXP2F_OPTIMIZATION }; + template <typename Params, typename Block_info> + inline __device__ Tile_o_epilogue_base(Params const& params, Block_info& block_info) + { + has_attention_sink_ = params.attention_sinks != nullptr; + head_idx_ = block_info.bidh; + attention_sink_ = has_attention_sink_ ? params.attention_sinks[block_info.bidh] : 0.f; + // It is only need when the exp2f optimization is enabled, so params.scale_bmm1 is always float. + scale_bmm1_f_ = reinterpret_cast<float const&>(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1); + }; + + // The attention sinks. + inline __device__ void add_attention_sink(float& sum, float max) + { + if (has_attention_sink_) + { + // The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled. + if constexpr (EXP2F_OPTIMIZATION) + { + sum += exp2f(attention_sink_ * M_LOG2E - max * scale_bmm1_f_); + } + else + { + sum += expf(attention_sink_ - max); + } + } + } + // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi]; + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. + add_attention_sink(global_sum_mi, global_max[mi]); + // The scale. + float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi; // Assume only N has multiple MMAs (MMAS_M = 1). #pragma unroll @@ -1096,12 +1130,21 @@ struct Tile_o_epilogue_base { float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi); float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1); - reg0 *= global_sum[mi]; - reg1 *= global_sum[mi]; + reg0 *= scale; + reg1 *= scale; } } } } + + // Whether the attention sink is enabled. + bool has_attention_sink_ = false; + // The attention sink value. + float attention_sink_ = 0.f; + // The float scale of bmm1 outputs. + float scale_bmm1_f_ = 1.f; + // The head idx. + int head_idx_ = 0; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1138,14 +1181,21 @@ struct Tile_o_epilogue<Hopper_hgmma_fp16_traits, Kernel_traits> using Base::Tile_o_epilogue_base; // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi]; - uint32_t const scale = float_to_half2(global_sum[mi]); + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. + this->add_attention_sink(global_sum_mi, global_max[mi]); + // The scale. + float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi; + // The scale. + const uint32_t scale_h = float_to_half2(scale); // Assume only N has multiple MMAs (MMAS_M = 1). #pragma unroll @@ -1155,7 +1205,7 @@ struct Tile_o_epilogue<Hopper_hgmma_fp16_traits, Kernel_traits> for (int ni = 0; ni < Mma_tile_o::CORES_N; ni++) { uint32_t& reg = ctile_o.acc_[0][mma_ni].reg(ni * Mma_tile_o::CORES_M + mi); - reg = hmul2(reg, scale); + reg = hmul2(reg, scale_h); } } } @@ -1215,27 +1265,58 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits> // The MMA tile for the BMM2. using Mma_tile_o = typename Base::Mma_tile_o; + // Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs). + enum + { + EXP2F_OPTIMIZATION = Base::EXP2F_OPTIMIZATION + }; + // Ctor. - template <typename Params> - inline __device__ Tile_o_epilogue(Params const& params) - : Base(params) + template <typename Params, typename Block_info> + inline __device__ Tile_o_epilogue(Params const& params, Block_info& block_info) + : Base(params, block_info) , scale_bmm2_(*params.scale_bmm2_d) { } + // Add the attention sink to the global sum. + inline __device__ void add_attention_sink(float& sum, float max) + { + if (this->has_attention_sink_) + { + // The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled. + // Take the log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to sum. + float quant_scale_in_log2 = log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE); + if constexpr (EXP2F_OPTIMIZATION) + { + sum += exp2f(this->attention_sink_ * M_LOG2E - max * this->scale_bmm1_f_ + quant_scale_in_log2); + } + else + { + sum += expf(this->attention_sink_ - max + quant_scale_in_log2); + } + } + } + // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. + add_attention_sink(global_sum_mi, global_max[mi]); #ifdef UNIFIED_EPILOGUE_SCALE // Descaling factor float const scale_bmm2_f_ = reinterpret_cast<float&>(scale_bmm2_); - global_sum[mi] = global_sum[mi] == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum[mi]; + // The scale. + float scale = global_sum_mi == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum_mi; #else - global_sum[mi] = global_sum[mi] == 0.f ? 1.0f : 1.0f / global_sum[mi]; + float scale = global_sum_mi == 0.f ? 1.0f : 1.0f / global_sum_mi; #endif // Assume only N has multiple MMAs (MMAS_M = 1). #pragma unroll @@ -1246,8 +1327,8 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits> { float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi); float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1); - reg0 *= global_sum[mi]; - reg1 *= global_sum[mi]; + reg0 *= scale; + reg1 *= scale; } } } diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp b/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp index 6d9811ac07..525171963a 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp @@ -29,33 +29,36 @@ using Kv_block_array = fmha::Kv_block_array; //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n, +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// +void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, + float softcapping_scale_bmm1, int warps_n, bool has_alibi); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale); //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,11 +84,11 @@ void run_sage_quant(unsigned int batch_size, unsigned int head_num, unsigned int //////////////////////////////////////////////////////////////////////////////////////////////////// -void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type, +void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type, float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1, - void* qkv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, - void* cu_q_seqlens_d, size_t const b, size_t const s, size_t const h, size_t const d, size_t const dv, - int const runs, int const warps_m, int const warps_n, bool const has_alibi) + void* qkv_d, void* vt_d, void* mask_d, void* attention_sinks_d, void* p_d, void* s_d, void* tmp_d, void* o_d, + void* softmax_sum_d, void* cu_q_seqlens_d, const size_t b, const size_t s, const size_t h, const size_t d, + const size_t dv, int const runs, int const warps_m, int const warps_n, bool const has_alibi) { cudaStream_t stream = 0; @@ -106,28 +109,28 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty // Softmax. if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_fp16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32) { - run_softmax_bf16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_bf16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_fp32(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_softmax, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax_e4m3(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, scale_softmax, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax_int8(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, + scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } else { @@ -179,7 +182,7 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params, // types Data_type data_type, Data_type acc_type, // sizes - size_t const b, size_t const s, size_t const h, size_t const d, size_t const packed_mask_stride, + const size_t b, const size_t s, const size_t h, const size_t d, const size_t packed_mask_stride, // device pointers void* qkv_d, void* packed_mask_d, void* o_d, void* p_d, void* s_d, // scale factors @@ -235,17 +238,17 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params, //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline void set_params(bert::Fused_multihead_attention_params_v2& params, Launch_params const launch_params, +static inline void set_params(bert::Fused_multihead_attention_params_v2& params, const Launch_params launch_params, // types Data_type data_type, Data_type acc_type, Data_type output_dtype, // attention input layout Attention_input_layout input_layout, // sizes - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const h_kv, size_t const d, - size_t const dv, size_t const total, const size_t num_grouped_heads, const size_t sliding_window_size, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, const size_t d, + const size_t dv, const size_t total, const size_t num_grouped_heads, const size_t sliding_window_size, const size_t chunked_attention_size, // paged kv cache block size. - size_t const tokens_per_block, + const size_t tokens_per_block, // device pointers void* qkv_packed_d, // contiguous q. @@ -261,8 +264,10 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, // offsets for different blocks in terms of the start address. int32_t* paged_block_offsets, // mask input. - void* packed_mask_d, void* cu_mask_rows_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, - void* s_d, void* softmax_stats_d, void* scale_bmm2_d, + void* packed_mask_d, void* cu_mask_rows_d, + // attention sinks. + void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, void* s_d, + void* softmax_stats_d, void* scale_bmm2_d, // scale factors float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1, // flags @@ -329,6 +334,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, // The N dimension has to be aligned. params.packed_mask_stride_in_bytes = (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8; + // Attention sinks. + params.attention_sinks = reinterpret_cast<float*>(attention_sinks_d); + #if defined(STORE_P) params.p_ptr = p_d; params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type); @@ -340,7 +348,7 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, #endif // defined(STORE_S) params.softmax_stats_ptr = softmax_stats_d; - params.softmax_stats_stride_in_bytes = get_size_in_bytes(h, DATA_TYPE_FP32); + params.softmax_stats_stride_in_bytes = get_size_in_bytes(h * 2, DATA_TYPE_FP32); // Set the dimensions. params.b = b; @@ -412,13 +420,13 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, size_t const s, - size_t const d, Attention_mask_type const attention_mask_type, Attention_input_layout const input_layout, +static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, const size_t s, + const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout, bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma, bool const force_non_flash_attention, bool const force_non_warp_specialization, bool const force_non_granular_tiling, bool const force_fp32_acc, // device props - cudaDeviceProp const props) + const cudaDeviceProp props) { // Set launch params to choose kernels @@ -573,6 +581,9 @@ int main(int argc, char** argv) // SageAttention block sizes int sage_block_size_q = 0, sage_block_size_k = 0, sage_block_size_v = 0; + // Use attention sinks (added to the denominator of softmax) + bool use_attention_sinks = false; + // Read the parameters from the command-line. for (int ii = 1; ii < argc; ++ii) { @@ -865,21 +876,28 @@ int main(int argc, char** argv) { sage_block_size_v = strtol(argv[ii], nullptr, 10); } + else if (!strcmp(argv[ii], "-use-attention-sinks")) + { + use_attention_sinks = true; + } else { fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]); return -1; } } - if (save_softmax == true) { - if (input_layout != Attention_input_layout::CONTIGUOUS_Q_KV) + bool is_MLA = (d == 192 && dv == 128); + if (((!is_MLA) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV) + || (is_MLA && input_layout != Attention_input_layout::Q_PAGED_KV + && input_layout != Attention_input_layout::SEPARATE_Q_K_V)) { - input_layout = Attention_input_layout::CONTIGUOUS_Q_KV; - printf( - "Only '--contiguous-q-kv' layout supports '-save-softmax', switched to " - "contiguous-q-kv\n"); + fprintf(stderr, + "For normal attention, Only '--contiguous-q-kv' layout supports " + "'-save-softmax'. For MLA only '-paged-kv' and '-separate-q-k-v' layout supports " + "'-save-softmax'.\n"); + exit(1); } if (data_type == DATA_TYPE_E4M3) { @@ -1043,11 +1061,11 @@ int main(int argc, char** argv) force_non_granular_tiling, force_fp32_acc, props); // The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D. - size_t const qkv_size = s * b * h * (2 * d + dv); + const size_t qkv_size = s * b * h * (2 * d + dv); // Allocate on the host. float* qkv_h = (float*) malloc(qkv_size * sizeof(float)); // The size in bytes. - size_t const qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type); + const size_t qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type); // Allocate on the device. void *qkv_sbh3d_d = nullptr, *qkv_bsh3d_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&qkv_sbh3d_d, qkv_size_in_bytes)); @@ -1057,7 +1075,7 @@ int main(int argc, char** argv) // The shape is [B, 2, S, H, D]. const size_t kv_size = b * s * h_kv * (d + dv); // The size in bytes. - size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); + const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); // Allocate on the host. void* contiguous_kv_h = malloc(kv_size_in_bytes); // Memset the buffer. @@ -1071,13 +1089,13 @@ int main(int argc, char** argv) void** kv_cache_ptrs_h = nullptr; void* kv_cache_pool_ptr = nullptr; int32_t *kv_cache_block_offsets_h, *kv_cache_block_offsets_d = nullptr; - size_t const max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block; - size_t const num_total_blocks = b * 2 * max_blocks_per_seq; + const size_t max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block; + const size_t num_total_blocks = b * 2 * max_blocks_per_seq; kv_cache_ptrs_h = (void**) malloc(num_total_blocks * sizeof(void*)); kv_cache_block_offsets_h = (int32_t*) malloc(num_total_blocks * sizeof(int32_t)); - size_t const paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type); + const size_t paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type); FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_block_offsets_d), num_total_blocks * sizeof(int32_t))); - size_t const kv_cache_pool_sz + const size_t kv_cache_pool_sz = get_size_in_bytes(num_total_blocks * tokens_per_block * h_kv * (d + dv) / 2, data_type); FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_pool_ptr), kv_cache_pool_sz)); size_t ptr_index = 0; @@ -1104,7 +1122,7 @@ int main(int argc, char** argv) // Q will always be [B, S, H, Dh] with paged kv cache. void* q_d; - size_t const q_size = s * b * h * d; + const size_t q_size = s * b * h * d; FMHA_CHECK_CUDA(cudaMalloc(&q_d, get_size_in_bytes(q_size, data_type))); // K has [B, S, H_kv, D] with separate kv cache. @@ -1122,11 +1140,11 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&scale_bmm2_d, sizeof(uint32_t))); // The mask for dropout or any mask patterns. - size_t const mask_size = s * b * s; + const size_t mask_size = s * b * s; // Allocate on the host. float* mask_h = (float*) malloc(mask_size * sizeof(float)); // The size in bytes. - size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); + const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); // Allocate on the device. void* mask_d = nullptr; if (!skip_checks) @@ -1158,7 +1176,7 @@ int main(int argc, char** argv) v1 ? 1 : 2); // The number of threads per CTA. - size_t const threads_per_cta = warps_m * warps_n * warps_k * 32; + const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m); // The number of mmas in the N dimension. @@ -1182,7 +1200,7 @@ int main(int argc, char** argv) packed_mask_size = b * mmas_m * mmas_n * threads_per_cta; } // The size in bytes. - size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); + const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); // Allocate on the host. uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes); // Set it to 0 (indicates that all elements are valid). @@ -1190,23 +1208,41 @@ int main(int argc, char** argv) // Allocate on the device. void* packed_mask_d = nullptr; + // The size of the attention sinks. + const size_t attention_sinks_size_in_bytes = h * sizeof(float); + + // The attention sinks. + void* attention_sinks_d = nullptr; + if (use_attention_sinks) + { + // Allocate on the host. + float* attention_sinks_h = (float*) malloc(attention_sinks_size_in_bytes); + // Randomly initialize the attention sinks. + random_init("attention_sinks", attention_sinks_h, 1, h, 1, false, 5.f, 1.f, verbose); + // Allocate on the device. + FMHA_CHECK_CUDA(cudaMalloc(&attention_sinks_d, attention_sinks_size_in_bytes)); + // Copy from the host to the device. + FMHA_CHECK_CUDA( + cudaMemcpy(attention_sinks_d, attention_sinks_h, attention_sinks_size_in_bytes, cudaMemcpyDefault)); + } + // The O matrix is packed as S * B * H * D. - size_t const o_size = s * b * h * dv; + const size_t o_size = s * b * h * dv; // Allocate on the host. float* o_h = (float*) malloc(o_size * sizeof(float)); // The size in bytes. - size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type); + const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type); // Allocate on the device. void* o_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes)); - // The softmax_stats_d vector is used to store the sum/max of the softmax per token + // The softmax_stats_d vector is used to store the max/sum of the softmax per token void* softmax_stats_d; FMHA_CHECK_CUDA(cudaMalloc(&softmax_stats_d, 2 * sizeof(float) * b * s * h)); FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * sizeof(float) * b * s * h)); // The size in bytes. - size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); + const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); // Allocate on the device. void* tmp_d = nullptr; if (data_type != acc_type) @@ -1216,13 +1252,13 @@ int main(int argc, char** argv) // Allocate the reference on the host. float* o_ref_h = (float*) malloc(o_size * sizeof(float)); - float* softmax_sum_ref_h = (float*) malloc(b * s * h * sizeof(float)); - float* softmax_sum_h = (float*) malloc(b * s * h * sizeof(float)); + float* softmax_stats_ref_h = (float*) malloc(2 * b * s * h * sizeof(float)); + float* softmax_stats_h = (float*) malloc(2 * b * s * h * sizeof(float)); // The P matrix is stored as one big matrix of size S x B x H x S. - size_t const p_size = s * b * h * s; + const size_t p_size = s * b * h * s; // The size in bytes. - size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type); + const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type); // Allocate on the device. void* p_d = nullptr; if (!skip_checks) @@ -1238,7 +1274,7 @@ int main(int argc, char** argv) #endif // defined(STORE_P) // The size in bytes of the S matrix (the data type may be different from P for int8). - size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type); + const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type); // Allocate on the device. void* s_d = nullptr; if (!skip_checks) @@ -1327,7 +1363,7 @@ int main(int argc, char** argv) std::vector<uint32_t> seqlens(b, 0); // randomly draw a batch of sequence lengths >= min_s std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(), - [=](uint32_t const) + [=](const uint32_t) { if (fix_s) { @@ -1415,7 +1451,7 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_packed_d, mqa_qkv_packed_size_in_bytes)); FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_d, mqa_qkv_size_in_bytes)); - size_t const o_packed_size = cu_seqlens.back() * h * dv; + const size_t o_packed_size = cu_seqlens.back() * h * dv; // Allocate on the host. float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float)); void* o_packed_d = nullptr; @@ -1676,9 +1712,9 @@ int main(int argc, char** argv) total, num_grouped_heads, sliding_window_size, chunked_attention_size, // Paged kv cache. tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d, - packed_mask_d, cu_mask_rows_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, softmax_stats_ptr, - scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, use_int8_scale_max, interleaved, - is_s_padded, has_alibi); + packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, + softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, + use_int8_scale_max, interleaved, is_s_padded, has_alibi); // total number of tokens is needed to set TMA desc on the host. launch_params.total_q_seqlen = q_seqlens[b]; @@ -1894,8 +1930,8 @@ int main(int argc, char** argv) ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, qkv_sbh3d_d, vt_d, // WAR pass in V' - mask_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, warps_m, warps_n, - has_alibi); + mask_d, attention_sinks_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, + warps_m, warps_n, has_alibi); timer.stop(); FMHA_CHECK_CUDA(cudaPeekAtLastError()); FMHA_CHECK_CUDA(cudaDeviceSynchronize()); @@ -1911,7 +1947,7 @@ int main(int argc, char** argv) // Read the results. FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_ref_h, o_d, o_size, data_type)); - FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_sum_ref_h, softmax_stats_d, b * s * h, DATA_TYPE_FP32)); + FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_ref_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32)); } // Fill-in p/s/o with garbage data. @@ -1997,7 +2033,7 @@ int main(int argc, char** argv) std::vector<float> o_ref_trans_h(o_size); FMHA_CHECK_CUDA(cuda_memcpy_d2h(o_h, o_d_view, o_view_size, output_dtype)); - FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_sum_h, softmax_stats_d, b * s * h, DATA_TYPE_FP32)); + FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32)); if (interleaved) { @@ -2009,7 +2045,6 @@ int main(int argc, char** argv) // Extract the last s_q tokens from the output. extract_and_transpose_output<float>( o_ref_trans_h.data(), o_ref_h, seqlens, q_seqlens, s, s_q, b, h, dv, is_s_padded); - if (verbose) { printf("\nChecking .....: O = V * S\n"); @@ -2018,8 +2053,8 @@ int main(int argc, char** argv) dv, epsilon, verbose, true); if (save_softmax) { - int errors = check_softmax_results(softmax_sum_h, softmax_sum_ref_h, b, s, h, seqlens, cu_seqlens); - status = status | (errors > 0); + auto errors = check_softmax_results(softmax_stats_h, softmax_stats_ref_h, b, s, h, seqlens, cu_seqlens); + status = status | ((errors.first + errors.second) > 0); } } if (status != 0) @@ -2114,8 +2149,8 @@ int main(int argc, char** argv) free(s_h); free(o_h); free(o_ref_h); - free(softmax_sum_h); - free(softmax_sum_ref_h); + free(softmax_stats_h); + free(softmax_stats_ref_h); free(contiguous_kv_h); free(kv_cache_ptrs_h); free(kv_cache_block_offsets_h); diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention.h index f77e3f14d0..4b6cfaae4a 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention.h +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention.h @@ -192,11 +192,14 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba void* packed_mask_ptr; // The mask input's stride in the N (K-seq) dimension. int64_t packed_mask_stride_in_bytes; - // The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max + // The Softmax stats vector of layout [total_tokens_q, h, 2], including softmax_max and softmax_sum void* softmax_stats_ptr; - // The stride between rows of softmax_stats_ptr + // The stride between rows of softmax_stats_ptr, default: h * sizeof(float2) int64_t softmax_stats_stride_in_bytes; + // The attention sinks (per head). + float* attention_sinks; + // array of length b+1 holding prefix sum of actual q sequence lengths. int* cu_q_seqlens; // array of length b+1 holding prefix sum of actual kv sequence lengths. diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h index 76670971e5..bacb4938cf 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h @@ -87,6 +87,8 @@ struct Fused_multihead_attention_params_v2 fmha::Kv_block_array paged_kv_cache; // The mask to implement drop-out. void* packed_mask_ptr; + // The attention sinks (per head). + float* attention_sinks; // The O matrix (output). void* o_ptr; // The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h index 245adc65a8..cd23452eaf 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h @@ -1294,27 +1294,47 @@ static void print_tensor( //////////////////////////////////////////////////////////////////////////////////////////////////// -static int check_softmax_results(float const* out, float const* ref, size_t b, size_t s, size_t h, +static std::pair<int, int> check_softmax_results(float const* out, float const* ref, size_t b, size_t s, size_t h, std::vector<uint32_t>& seqlens, std::vector<int>& cu_seqlens) { - int n_errors = 0; + int n_errors_max = 0; + int n_errors_sum = 0; + + // Check the max for (int b_ = 0; b_ < b; ++b_) { for (int s_ = 0; s_ < seqlens[b_]; ++s_) { for (int h_ = 0; h_ < h; ++h_) { - uint64_t idx = cu_seqlens[b_] * h + s_ * h + h_; + uint64_t idx = (cu_seqlens[b_] + s_) * h * 2 + h_ * 2; float sum = out[idx]; float sum_ref = ref[idx]; if (sum_ref != 1.0f && fabsf(sum - sum_ref) / (fabsf(sum) + fabsf(sum_ref)) > 0.01) { - n_errors++; + n_errors_max++; } } } } - return n_errors; + // Check the sum + for (int b_ = 0; b_ < b; ++b_) + { + for (int s_ = 0; s_ < seqlens[b_]; ++s_) + { + for (int h_ = 0; h_ < h; ++h_) + { + uint64_t idx = (cu_seqlens[b_] + s_) * h * 2 + h_ * 2 + 1; + float sum = out[idx]; + float sum_ref = ref[idx]; + if (sum_ref != 1.0f && fabsf(sum - sum_ref) / (fabsf(sum) + fabsf(sum_ref)) > 0.01) + { + n_errors_sum++; + } + } + } + } + return {n_errors_max, n_errors_sum}; } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp b/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp index 8a2e7a8fc0..6e37fc6ab4 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp +++ b/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp @@ -23,28 +23,30 @@ using Launch_params = bert::Fused_multihead_attention_launch_params; //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n, +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, + float softcapping_scale_bmm1, int warps_n, bool has_alibi); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale); //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -57,10 +59,10 @@ void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h //////////////////////////////////////////////////////////////////////////////////////////////////// -void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type, +void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type, float const scale_bmm1, float const scale_softmax, float const scale_bmm2, void* q_d, void* kv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, void* cu_seqlens_q_d, - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, int const runs, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, int const runs, int const warps_m, int const warps_n, bool has_alibi) { @@ -84,20 +86,22 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty // Softmax. if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); + run_softmax_fp16( + s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); + run_softmax_fp32( + s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, 0.f, - warps_n, has_alibi); + run_softmax_e4m3(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, + 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1, + run_softmax_int8(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1, scale_softmax, 0.f, warps_n, has_alibi); } else @@ -148,8 +152,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_mhca& param // types Data_type data_type, Data_type acc_type, // sizes - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, size_t const d_padded, - size_t const total, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, const size_t d_padded, + const size_t total, // device pointers void* q_packed_d, void* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d, void* p_d, void* s_d, @@ -515,17 +519,17 @@ int main(int argc, char** argv) launch_params.use_tma = use_tma; // The Q matrix of size S_Q x B x H x D. - size_t const q_size = s_q * b * h * d; + const size_t q_size = s_q * b * h * d; // The K and V matrices are packed into one big matrix of size S_KV x B x H x 2 x D. - size_t const kv_size = s_kv_padded * b * h * 2 * d; + const size_t kv_size = s_kv_padded * b * h * 2 * d; // Allocate on the host. float* q_h = (float*) malloc(q_size * sizeof(float)); // Allocate on the host. float* kv_h = (float*) malloc(kv_size * sizeof(float)); // The size in bytes. - size_t const q_size_in_bytes = get_size_in_bytes(q_size, data_type); + const size_t q_size_in_bytes = get_size_in_bytes(q_size, data_type); // The size in bytes. - size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); + const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); // Allocate on the device. void* q_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&q_d, q_size_in_bytes)); @@ -534,11 +538,11 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&kv_d, kv_size_in_bytes)); // The mask for dropout. - size_t const mask_size = s_q * b * s_kv_padded; + const size_t mask_size = s_q * b * s_kv_padded; // Allocate on the host. float* mask_h = (float*) malloc(mask_size * sizeof(float)); // The size in bytes. - size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); + const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); // Allocate on the device. void* mask_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes)); @@ -554,28 +558,28 @@ int main(int argc, char** argv) v1 ? 1 : 2); // The number of threads per CTA. - size_t const threads_per_cta = warps_m * warps_n * warps_k * 32; + const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. - size_t const mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m); + const size_t mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m); // The number of mmas in the N dimension. - size_t const mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n); + const size_t mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n); // We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask). assert(!v1 || mmas_n <= 4); // The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA. - size_t const packed_mask_size = b * mmas_m * threads_per_cta; + const size_t packed_mask_size = b * mmas_m * threads_per_cta; // The size in bytes. - size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); + const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); // Allocate on the host. uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes); // Allocate on the device. void* packed_mask_d = nullptr; // The O matrix is packed as S_Q * B * H * D. - size_t const o_size = s_q * b * h * d; + const size_t o_size = s_q * b * h * d; // Allocate on the host. float* o_h = (float*) malloc(o_size * sizeof(float)); // The size in bytes. - size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type); + const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type); // Allocate on the device. void* o_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes)); @@ -587,7 +591,7 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMemset(softmax_max_d, 0x00, sizeof(float) * b * s_q * h)); // The size in bytes. - size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); + const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); // Allocate on the device. void* tmp_d = nullptr; if (data_type != acc_type) @@ -599,9 +603,9 @@ int main(int argc, char** argv) float* o_ref_h = (float*) malloc(o_size * sizeof(float)); // The P matrix is stored as one big matrix of size S_Q x B x H x S_KV. - size_t const p_size = s_q * b * h * s_kv_padded; + const size_t p_size = s_q * b * h * s_kv_padded; // The size in bytes. - size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type); + const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type); // Allocate on the device. void* p_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes)); @@ -614,7 +618,7 @@ int main(int argc, char** argv) #endif // defined(STORE_P) // The size in bytes of the S matrix (the data type may be different from P for int8). - size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type); + const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type); // Allocate on the device. void* s_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes)); @@ -634,9 +638,9 @@ int main(int argc, char** argv) // WAR fOR MISSING CUBLAS FP8 NN SUPPORT. // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V. - size_t const v_size = s_kv_padded * b * h * d; + const size_t v_size = s_kv_padded * b * h * d; // The size in bytes. - size_t const v_size_in_bytes = get_size_in_bytes(v_size, data_type); + const size_t v_size_in_bytes = get_size_in_bytes(v_size, data_type); float* vt_h = (float*) malloc(v_size * sizeof(float)); void* vt_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&vt_d, v_size_in_bytes)); @@ -676,7 +680,7 @@ int main(int argc, char** argv) = [min_s, fix_s, b](int s, std::vector<uint32_t>& seqlens, std::vector<int>& cu_seqlens, void** cu_seqlens_d) { std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(), - [=](uint32_t const) + [=](const uint32_t) { if (fix_s) { @@ -728,7 +732,7 @@ int main(int argc, char** argv) void* kv_packed_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&kv_packed_d, kv_packed_size_in_bytes)); - size_t const o_packed_size = cu_seqlens_q.back() * h * d; + const size_t o_packed_size = cu_seqlens_q.back() * h * d; // Allocate on the host. float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float)); float* o_ref_packed_h = (float*) malloc(o_packed_size * sizeof(float)); diff --git a/cpp/kernels/fmha_v2/src/softmax_bf16.cu b/cpp/kernels/fmha_v2/src/softmax_bf16.cu index 5212d31717..79b681b502 100644 --- a/cpp/kernels/fmha_v2/src/softmax_bf16.cu +++ b/cpp/kernels/fmha_v2/src/softmax_bf16.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax<fmha::bf16_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax<fmha::bf16_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp16.cu b/cpp/kernels/fmha_v2/src/softmax_fp16.cu index 1fb68b1136..9df37605a2 100644 --- a/cpp/kernels/fmha_v2/src/softmax_fp16.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp16.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax<uint16_t, uint16_t>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax<uint16_t, uint16_t>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, + h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp32.cu b/cpp/kernels/fmha_v2/src/softmax_fp32.cu index 2b3bb6acbb..12bcd8624d 100644 --- a/cpp/kernels/fmha_v2/src/softmax_fp32.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp32.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax<fmha::fp16_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax<fmha::fp16_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp8.cu b/cpp/kernels/fmha_v2/src/softmax_fp8.cu index 0a8e5f5029..26c2f5e88d 100644 --- a/cpp/kernels/fmha_v2/src/softmax_fp8.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp8.cu @@ -12,10 +12,10 @@ #include "softmax_impl.h" -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi) +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi) { - run_softmax<fmha::e4m3_t, float>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, - scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax<fmha::e4m3_t, float>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_impl.h b/cpp/kernels/fmha_v2/src/softmax_impl.h index 2bc9f3380b..c26a18384a 100644 --- a/cpp/kernels/fmha_v2/src/softmax_impl.h +++ b/cpp/kernels/fmha_v2/src/softmax_impl.h @@ -10,6 +10,7 @@ * its affiliates is strictly prohibited. */ +#include <cfloat> #include <cstdio> #include <fmha/numeric_types.h> #include <fmha/utils.h> @@ -33,6 +34,8 @@ struct Softmax_params Src_type const* src; // Masks. int8_t const* mask; + // Attention sinks (per head). + float const* attention_sinks; // Softmax sum pointer. float* softmax_sum; // ALiBi @@ -148,7 +151,8 @@ static inline __device__ float apply_exp_(float x, float max) //////////////////////////////////////////////////////////////////////////////////////////////////// template <int N> -static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32) +static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32, + float& max_fp32, float const attention_sink) { // Apply the masks. @@ -159,7 +163,6 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma } // Compute the max inside the thread. - float max_fp32 = -HUGE_VALF; #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -233,7 +236,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma } // Normalize. - float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -244,7 +247,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma //////////////////////////////////////////////////////////////////////////////////////////////////// template <int N> -static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32) +static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32, + float& max_fp32, float const attention_sink) { // Apply the masks. #pragma unroll @@ -255,7 +259,6 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma } // Compute the max inside the thread. - float max_fp32 = -HUGE_VALF; #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -401,7 +404,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma } // Normalize. - float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -413,7 +416,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma //////////////////////////////////////////////////////////////////////////////////////////////////// template <int N> -static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32) +static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32, + float& max_fp32, float const attention_sink) { // Apply the masks. @@ -427,7 +431,6 @@ static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&ma } // Compute the max inside the thread. - float max_fp32 = -HUGE_VALF; #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -824,7 +827,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&ma } // Normalize. - float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -994,13 +997,26 @@ static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params) } } + // The attention sink value. + float attention_sink = -FLT_MAX; + if (params.attention_sinks != nullptr) + { + attention_sink = params.attention_sinks[hi]; + } + // Do the reduction. float sum_fp32 = 0.f; - reduce(data_fp32, mask_, params.warps_n, sum_fp32); + float max_fp32 = -HUGE_VALF; + reduce(data_fp32, mask_, params.warps_n, sum_fp32, max_fp32, attention_sink); if (threadIdx.x == 0) { int sum_s = params.cu_q_seqlens[bi]; - params.softmax_sum[sum_s * params.h + si * params.h + hi] = sum_fp32; + // [B, S, H, 2] {max, sum} float + if (hi < params.h) + { + params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2] = max_fp32; + params.softmax_sum[(sum_s + si) * params.h * 2 + hi * 2 + 1] = sum_fp32; + } } // Reconvert to half. DstX_type data_dst[VECs_PER_THREAD]; @@ -1025,9 +1041,9 @@ static __global__ void softmax_kernel(Softmax_params<Dst_type, Src_type> params) //////////////////////////////////////////////////////////////////////////////////////////////////// template <typename Dst_type, typename Src_type> -void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum, void* cu_q_seqlens, int s_inner, - int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi) +void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum, + void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, + float softcapping_scale_bmm1, int warps_n, bool has_alibi) { Softmax_params<Dst_type, Src_type> params; @@ -1039,6 +1055,7 @@ void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum params.softmax_sum = reinterpret_cast<float*>(softmax_sum); params.cu_q_seqlens = reinterpret_cast<int*>(cu_q_seqlens); params.mask = reinterpret_cast<int8_t const*>(mask); + params.attention_sinks = reinterpret_cast<float const*>(attention_sinks); params.has_alibi = has_alibi; // The dimensions and precomputed values. diff --git a/cpp/kernels/fmha_v2/src/softmax_int8.cu b/cpp/kernels/fmha_v2/src/softmax_int8.cu index 772fe1520c..28701de978 100644 --- a/cpp/kernels/fmha_v2/src/softmax_int8.cu +++ b/cpp/kernels/fmha_v2/src/softmax_int8.cu @@ -12,10 +12,10 @@ #include "softmax_impl.h" -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, - int warps_n, bool has_alibi) +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, + float softcapping_scale_bmm1, int warps_n, bool has_alibi) { - run_softmax<int8_t, int32_t>(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, scale_bmm1, - scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax<int8_t, int32_t>(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, + scale_bmm1, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/xqa/mha.cu b/cpp/kernels/xqa/mha.cu index c9690cbc6b..69d93e901c 100644 --- a/cpp/kernels/xqa/mha.cu +++ b/cpp/kernels/xqa/mha.cu @@ -1379,6 +1379,19 @@ __device__ inline ThrdRegRowMax mergeRowMax( return mergedRowMax; } +__device__ inline void addAttentionSinks( + ThrdRegRowMax& globalRowSum, ThrdRegRowMax const globalRowMax, float const* attentionSinks) +{ + for (uint32_t i = 0; i < globalRowSum.size; i++) + { + uint32_t srcOffset = warp_size * i + laneId(); + if (srcOffset < headGrpSize) + { + globalRowSum[i] += expf(attentionSinks[srcOffset] - globalRowMax[i]); + } + } +} + #ifdef NDEBUG __device__ __forceinline__ #else @@ -1405,6 +1418,7 @@ CUBIN_EXPORT __global__ #if SPEC_DEC MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32)]. #endif + float const* attentionSinks, // [headGrpSize] #ifdef NDEBUG KVCacheList<usePagedKVCache> const& cacheList, #if BEAM_WIDTH > 1 @@ -2371,6 +2385,12 @@ CUBIN_EXPORT __global__ float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F); if (seqIterInit < nbSeqIters) { // otherwise rcpRowSum will be NAN. + // The attention sinks are moved to the multi-block reduction part if the multi-block is enabled. + if (!isMultiBlock && attentionSinks != nullptr) + { + // Attention sinks are per head. + addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp); + } ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum); #if LOW_PREC_OUTPUT voScale *= rcpOutScale[0]; @@ -2559,6 +2579,11 @@ CUBIN_EXPORT __global__ assert(std::isfinite(mergedRowSum[0])); } } + if (attentionSinks != nullptr) + { + // Attention sinks are per head. + addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp); + } __syncthreads(); rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum)); GemmOutRegTile const mergedOutTile = toFp16(sumAcc); @@ -2615,6 +2640,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32))] uint2 (each bit represents mask for one col // position). #endif + float const* attentionSinks, // [headGrpSize] KVCacheList<usePagedKVCache> const cacheList, #if BEAM_WIDTH > 1 BeamSearchParams const beamSearchParams, @@ -2640,7 +2666,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif @@ -2667,6 +2693,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -2760,7 +2787,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif @@ -2788,7 +2815,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif diff --git a/cpp/kernels/xqa/mha.h b/cpp/kernels/xqa/mha.h index 39c94f985e..d35ad48104 100644 --- a/cpp/kernels/xqa/mha.h +++ b/cpp/kernels/xqa/mha.h @@ -101,6 +101,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -140,6 +141,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu index 88d4c75e30..9a438df9a2 100644 --- a/cpp/kernels/xqa/mha_sm90.cu +++ b/cpp/kernels/xqa/mha_sm90.cu @@ -428,6 +428,7 @@ __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src); __device__ void storeGemm0AccToShm( uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc); __device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec); +__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound); #else __device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src); __device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd); @@ -453,7 +454,8 @@ __device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec template <bool dstIsStrided = false, typename DstHead> __device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, - ShmQWiseVec const& accColSum, uint32_t nbKHeads = 0 /* only for final result in spec dec. */); + ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, + uint32_t nbKHeads = 0 /* only for final result in spec dec. */); #else __device__ void transposeVTile( uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src); @@ -651,6 +653,7 @@ CUBIN_EXPORT __global__ #else IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], #endif + float const* attentionSinks, // [headGrpSize] KVCacheList<usePagedKVCache> const cacheList, #if USE_BEAM_SEARCH BeamSearchParams const beamSearchParams, @@ -1252,7 +1255,7 @@ CUBIN_EXPORT __global__ IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast<IOHead>(); #if SWAP_AB finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, - smem.gemm1WarpGrpBar, smem.gemm1AccColSum); + smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, nullptr); #else finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, smem.gemm1AccColSum, 1, ctaNbValidTokens); @@ -1262,9 +1265,16 @@ CUBIN_EXPORT __global__ { uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); OutputHead* const dst = &output[outOffset]; + ShmQWiseVec const* attentionSinksVec = nullptr; + if (attentionSinks != nullptr) + { + attentionSinksVec + = reinterpret_cast<ShmQWiseVec const*>(attentionSinks + headGrpSize * idxHeadGrp); + } #if SWAP_AB finalizeAndWriteOut_sync<SPEC_DEC>(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, - xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, nbKHeads); + xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, attentionSinksVec, + nbKHeads); #else finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens); @@ -1585,6 +1595,17 @@ CUBIN_EXPORT __global__ } unused(bar.consumed.arrive()); } + // Add the attention sinks. + if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headsPerWarp; i++) + { + uint32_t const idxHead = wid + nbMathWarps * i; + float sink = expf( + attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - states[i].max); + states[i].sum += sink; + } + } __syncthreads(); uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); auto const dst = &output[outOffset]; @@ -2029,6 +2050,22 @@ __device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smem return ret; } +__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound) +{ + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) + { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast< + Vec<Vec<float, GmmaAccCoreMat::cols>, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; + } + return ret; +} + __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd) { uint32_t const idxInQuad = laneId() % 4; @@ -2878,12 +2915,19 @@ __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRa template <bool dstIsStrided, typename DstHead> __device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, - ShmQWiseVec const& accColSum, uint32_t nbKHeads) + ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads) { // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp // static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of // mufu.rcp"); - auto const regColSum = loadShmColWiseVecWithDup(accColSum); + auto regColSum = loadShmColWiseVecWithDup(accColSum); + if (attentionSinksVec != nullptr) + { + auto const regAccColMax = loadShmColWiseVecWithDup(accColMax); + auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1); + auto regColSinks = expf(regAttentionSinks - regAccColMax); + regColSum = regColSum + regColSinks; + } auto const regOutScale = __frcp_rn(regColSum) * xvoScale; rescaleAcc(acc, regOutScale); @@ -3175,6 +3219,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -3286,7 +3331,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else q, #endif - cacheList, + attentionSinks, cacheList, #if USE_BEAM_SEARCH beamSearchParams, #endif @@ -3322,7 +3367,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else q, #endif - cacheList, + attentionSinks, cacheList, #if USE_BEAM_SEARCH beamSearchParams, #endif diff --git a/cpp/kernels/xqa/mla_sm120.cu b/cpp/kernels/xqa/mla_sm120.cu index 74877512a7..072908fe3e 100644 --- a/cpp/kernels/xqa/mla_sm120.cu +++ b/cpp/kernels/xqa/mla_sm120.cu @@ -1859,12 +1859,13 @@ CUtensorMap makeTensorMapForQ( #endif // IS_MLA void launchMLA(cudaDeviceProp const& prop, - uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed + uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed float qScale, OutputHead* output, InputHead const* q, + float* attentionSinks, // [headGrpSize], not supported. #if USE_PAGED_KV_CACHE - GMemCacheHead* pool, // global pool of pages + GMemCacheHead* pool, // global pool of pages KVCachePageIndex const* - kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] + kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] #else GMemKVCacheHead* kvCacheData, #endif diff --git a/cpp/kernels/xqa/test/refAttention.cpp b/cpp/kernels/xqa/test/refAttention.cpp index d8f1a688f5..dd356c101c 100644 --- a/cpp/kernels/xqa/test/refAttention.cpp +++ b/cpp/kernels/xqa/test/refAttention.cpp @@ -45,7 +45,7 @@ using Vector = Matrix<Type, Size, 1>; template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch> Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q, CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize) + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) { uint32_t const nbTiles = divUp(seqLen, tileSize); auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval(); @@ -113,6 +113,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt } rowSum += tileRowSum; } + + // Add the attention sinks. + if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headGrpSize; i++) + { + rowSum[i] += expf(attentionSinks[i] - rowMax[i]); + } + } + Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> out = gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array()); std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); }); @@ -123,7 +133,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \ refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \ CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \ - float qScale, float kvScale, float xScale, uint32_t slidingWinSize) + float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) INSTANTIATE_refFlashAttention(CacheElem, 64, false, false); INSTANTIATE_refFlashAttention(CacheElem, 64, false, true); @@ -143,7 +153,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti #else Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q, CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize) + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) { #endif float const rcpXScale = 1.f / xScale; @@ -184,7 +194,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti Eigen::Matrix<float, headGrpSize, Eigen::Dynamic, Eigen::RowMajor> x = (gemm0Acc.colwise() - rowMax).array().exp().eval(); - Eigen::Vector<float, headGrpSize> const rowSum = x.rowwise().sum().eval(); + Eigen::Vector<float, headGrpSize> rowSum = x.rowwise().sum().eval(); std::for_each(x.data(), x.data() + x.size(), [&](float& e) { e = float(MathElem(e * rcpXScale)); }); @@ -200,6 +210,18 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti } } } + + // Add the attention sinks. +#if !SPEC_DEC + if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headGrpSize; i++) + { + rowSum[i] += expf(attentionSinks[i] - rowMax[i]); + } + } +#endif + Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> out = gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array()); std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); }); @@ -217,7 +239,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \ refAttention<prec, isPaged, useBeamSearch>(IOHead const* q, CacheSeq<isPaged, useBeamSearch> const& k, \ CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale, float kvScale, float xScale, \ - uint32_t slidingWinSize) + uint32_t slidingWinSize, float* attentionSinks) #endif INSTANTIATE_refAttention(InputElem, false, false); INSTANTIATE_refAttention(InputElem, false, true); diff --git a/cpp/kernels/xqa/test/refAttention.h b/cpp/kernels/xqa/test/refAttention.h index bfab141829..a073ed0e80 100644 --- a/cpp/kernels/xqa/test/refAttention.h +++ b/cpp/kernels/xqa/test/refAttention.h @@ -83,7 +83,7 @@ struct CacheSeq<true, true> template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch> Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q, CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize); + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks); template <typename MathElem, bool isPaged, bool useBeamSearch> #if SPEC_DEC @@ -93,7 +93,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttenti #else Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q, CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize); + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks); #endif template <uint32_t ropeStyle> diff --git a/cpp/kernels/xqa/test/test.cpp b/cpp/kernels/xqa/test/test.cpp index b922857862..91b35f3e1a 100644 --- a/cpp/kernels/xqa/test/test.cpp +++ b/cpp/kernels/xqa/test/test.cpp @@ -130,7 +130,7 @@ template <uint32_t nbKHeads> #endif #endif void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false, - bool saveData = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30) + bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30) { #if IS_MLA if (nbKHeads != 1) @@ -613,6 +613,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, } } + // Allocate the attention sinks (per head) + auto attentionSinks = ManagedMemBuf<float>(nbQHeads); + // The attention sinks ptr. + float* attentionSinksPtr = hasAttentionSinks ? reinterpret_cast<float*>(attentionSinks.get()) : nullptr; + // Initialize the attention sinks (use large values to detect the potential bugs). + for (uint32_t i = 0; i < nbQHeads; i++) + { + // Range: [2, 5] + attentionSinks.get()[i] = 2.f + float(i % 4); + } + if (verbose) { printf("migrating data to gpu\n"); @@ -640,6 +651,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, #if BEAM_WIDTH > 1 cacheIndir.prefetch(dev, stream); #endif + attentionSinks.prefetch(dev, stream); }; prefetchToDevice(device); checkCuda(cudaMemsetAsync(semaphores.get(), 0, 4 * nbSemaphores, stream)); @@ -720,6 +732,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, &qHeads[0][0][0], #endif #endif + attentionSinksPtr, #if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE cacheKHeads.get(), cacheVHeads.get(), #else @@ -1028,10 +1041,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, hostMask, qSeqLen, q_len); #else Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refOutput; + auto const refAttentionSinks + = hasAttentionSinks ? attentionSinksPtr + headGrpSize * idxKHead : nullptr; if (useQGMMA) { refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, - vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); + vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, + refAttentionSinks); // refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, // vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); } @@ -1039,8 +1055,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, { // refOutput = refFlashAttention<InputElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], // kCacheSeq, vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale); - refOutput = refAttention<InputElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, - vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); + refOutput + = refAttention<InputElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, vCacheSeq, + seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks); } #endif if (lowPrecOutput) @@ -1196,11 +1213,23 @@ TEST(RefCheck, llama_V2_70b) runTest<2>(2, 514, false, true); runTest<1>(1, 4096, false, true); #if SLIDING_WINDOW - runTest<2>(2, 4096, false, true, false, false, ~0, 256); - runTest<2>(2, 400, false, true, false, false, ~0U, 256); + runTest<2>(2, 4096, false, true, false, false, false, ~0, 256); + runTest<2>(2, 400, false, true, false, false, false, ~0U, 256); #endif runTest<8>(120, 367, false, true); - // runTest<8>(1792, 2048, false, true); + runTest<8>(1792, 2048, false, true); +} + +TEST(RefCheck, attention_sinks) +{ + auto runAttentionSinksTest = [](uint32_t batchSize, uint32_t seqLen) + { runTest<8>(batchSize, seqLen, false, true, false, false, /*hasAttentionSinks*/ true); }; + + runAttentionSinksTest(2, 2); + runAttentionSinksTest(2, 15); + runAttentionSinksTest(2, 256); + runAttentionSinksTest(2, 514); + runAttentionSinksTest(1, 4096); } TEST(Perf, tracing_long) @@ -1264,7 +1293,7 @@ TEST(Perf, mlperf_gptj) #ifndef NDEBUG GTEST_SKIP() << "Skipping perf tests for debug build"; #endif - runTest<32>(396, 800 + 224, true, false, false, false, 800); + runTest<32>(396, 800 + 224, true, false, false, false, false, 800); } TEST(Perf, mlperf_llama) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h index 565c170e1d..2559ae5484 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h @@ -53,6 +53,7 @@ using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput; using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner; using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; +using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams; using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation; static BufferManager::CudaStreamPtr streamPtr; @@ -980,11 +981,11 @@ public: auto stream = streamPtr->get(); MoeMinLatencyParams min_latency_params; #ifdef USING_OSS_CUTLASS_MOE_GEMM - mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, + mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true, mSelectedExperts + mSelectedExpertsSize * mBufferIndex, mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr, mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex, - mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex, + ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex, mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex, mFinalOutput + mFinalOutputSize * mBufferIndex, @@ -992,11 +993,11 @@ public: /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex], /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream); #else - mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, + mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true, mSelectedExperts + mSelectedExpertsSize * mBufferIndex, mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr, mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex, - mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex, + ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex, mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex, mFinalOutput + mFinalOutputSize * mBufferIndex, diff --git a/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp b/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp index d25572ad5a..514d100fe5 100644 --- a/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp +++ b/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp @@ -30,6 +30,11 @@ void tensorrt_llm::batch_manager::AssignReqSeqSlots::operator()(SequenceSlotMana { for (auto const& llmReq : requests) { + if (llmReq->isDisaggGenerationInitState()) + { + // Skip assigning sequence slot for DISAGG_GENERATION_INIT request + continue; + } auto const isReqNew = (llmReq->isContextInitState() && !llmReq->mSeqSlot) || (llmReq->isDisaggGenerationTransmissionComplete()); if (isReqNew && llmReq->getReturnPerfMetrics()) diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 2edfd5f77a..503c2e6c5d 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -360,7 +360,7 @@ void CacheFormatter::format(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (connections.size() > 1) @@ -712,7 +712,7 @@ void CacheFormatter::unformat(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (pickUpConnections.size() > 1) { @@ -846,16 +846,23 @@ void CacheFormatter::unformat(TransferSession& session) } int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(); int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism; + int destPPSize = destConfig.getParallelConfig().mPipelineParallelism; + int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); + + if (selfPPSize == destPPSize) + { + return true; + } if (selfNumLayers % selfPPSize != 0) { - TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers must be divisible by pipeline parallelism"); + TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d", + selfNumLayers, selfPPSize); return false; } - int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); - int destPPSize = destConfig.getParallelConfig().mPipelineParallelism; if (destNumLayers % destPPSize != 0) { - TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers must be divisible by pipeline parallelism"); + TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d ", + destNumLayers, destPPSize); return false; } return true; diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index ee199c2fb1..8ae8ee5f2c 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -76,15 +76,6 @@ public: /// @brief Destructor. virtual ~BaseCacheFormatter() = default; - - // TODO: better way for context/generation tagging - void markAsSender(bool isSender) - { - kvCacheMeasureHelper.markAsSender(isSender); - } - -protected: - KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()}; }; // Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index 93df2f96ec..16771709bb 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -44,14 +44,16 @@ namespace tensorrt_llm::batch_manager using SizeType32 = CreateNewDecoderRequests::SizeType32; using TensorPtr = CreateNewDecoderRequests::TensorPtr; +using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr; namespace { void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffers& inputBuffers, - ITensor& sequenceLengths, SizeType32 beamWidth, runtime::BufferManager const& manager, - runtime::CudaStream const& stream) + ITensor& sequenceLengths, SizeType32 beamWidth, runtime::CudaStream const& stream) { + auto const bufferManager = BufferManager{std::make_shared<CudaStream>(stream.get())}; + auto const batchSize = contextRequests.size(); auto batchSlotsView = tr::ITensor::slice(inputBuffers.setupBatchSlots, 0, batchSize); auto fillValuesView = tr::ITensor::slice(inputBuffers.fillValues, 0, batchSize); @@ -79,8 +81,8 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe auto batchSlotsDeviceView = tr::ITensor::slice(inputBuffers.setupBatchSlotsDevice, 0, batchSize); auto fillValuesViewDevice = tr::ITensor::slice(inputBuffers.fillValuesDevice, 0, batchSize); - manager.copy(*batchSlotsView, *batchSlotsDeviceView); - manager.copy(*fillValuesView, *fillValuesViewDevice); + bufferManager.copy(*batchSlotsView, *batchSlotsDeviceView); + bufferManager.copy(*fillValuesView, *fillValuesViewDevice); tr::kernels::invokeFillBatch(sequenceLengths, *batchSlotsDeviceView, beamWidth, *fillValuesViewDevice, stream); } } @@ -127,10 +129,10 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe std::tuple<TensorPtr, std::vector<runtime::SamplingConfig>, std::vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>> CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, - executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, - runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, - runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream, - SizeType32 maxSequenceLength, SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const + executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, nvinfer1::DataType logitsType, + DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, + CudaStream const& decoderStream, SizeType32 maxSequenceLength, SizeType32 beamWidth, + OptionalRef<MedusaBuffers const> medusaBuffers) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(CreateNewDecoderRequests); @@ -141,13 +143,13 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru if (!finishedContextRequests.empty()) { - copySequenceLengths(finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth, - bufferManager, runtimeStream); + copySequenceLengths( + finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth, runtimeStream); } - auto [lookaheadPrompt, lookaheadAlgoConfigs] = createDecoderRequests(finishedContextRequests, - inputBuffers.inputsIds, decodingConfig, decoderState, bufferManager, logitsType, modelConfig, worldConfig, - runtimeStream, decoderStream, maxSequenceLength, medusaBuffers); + auto [lookaheadPrompt, lookaheadAlgoConfigs] + = createDecoderRequests(finishedContextRequests, inputBuffers.inputsIds, decodingConfig, decoderState, + logitsType, modelConfig, worldConfig, runtimeStream, decoderStream, maxSequenceLength, medusaBuffers); auto const batchSize = finishedContextRequests.size(); @@ -165,115 +167,122 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru std::move(lookaheadAlgoConfigs)}; } -void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, - runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream, - SizeType32 maxSequenceLength) +namespace { - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - TLLM_CHECK(batchSlot >= 0); - - BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())}; - - auto const batchSize = decoderState.getMaxBatchSize(); - TLLM_CHECK(0 <= batchSize && batchSlot < batchSize); - auto const maxBeamWidth = decoderState.getMaxBeamWidth(); - auto const beamWidth = samplingConfig.beamWidth; - TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth, - tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.", - beamWidth, maxBeamWidth)); - auto const& requestIds = request.ids; - auto const inputLength = request.inputLen; - auto const numDecodingEngineTokens = request.generatedTokensPerEngineStep; +void initializeInputLengths(DecodingInput& dJointInput, SizeType32 batchSlot, SizeType32 inputLength, + std::optional<SizeType32> maxNewTokensOpt, SizeType32 numDecodingEngineTokens, SizeType32 maxSequenceLength, + BufferManager const& manager) +{ auto const numDecodingDraftEngineTokens = numDecodingEngineTokens - 1; - auto const maxNewTokens - = request.maxNewTokens.value_or(maxSequenceLength - inputLength - numDecodingDraftEngineTokens); + auto const maxNewTokens = maxNewTokensOpt.value_or(maxSequenceLength - inputLength - numDecodingDraftEngineTokens); TLLM_CHECK_WITH_INFO(inputLength + maxNewTokens + numDecodingDraftEngineTokens <= maxSequenceLength, tc::fmtstr( "Input length (%d) + max new tokens (%d) + draft tokens (%d) must be less than max sequence length (%d).", inputLength, maxNewTokens, numDecodingDraftEngineTokens, maxSequenceLength)); - TLLM_CHECK(requestIds->getDataType() == TRTDataType<TokenIdType>::value); - auto const endId = request.endId.value_or(-1); - // input - auto& dJointInput = decoderState.getJointDecodingInput(); + TensorPtr const sequenceLimitLength{ + ITensor::slice(constPointerCast(dJointInput.sequenceLimitLength), batchSlot, 1)}; + runtime::kernels::invokeFill(*sequenceLimitLength, inputLength + maxNewTokens, manager.getStream()); - dJointInput.beamWidths.at(batchSlot) = beamWidth; - decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens); + TensorPtr const inputLengths{ITensor::slice(constPointerCast(dJointInput.lengths), batchSlot, 1)}; + runtime::kernels::invokeFill(*inputLengths, inputLength, manager.getStream()); +} +void initializeRequestIds(DecodingInput& dJointInput, DecodingOutput& dJointOutput, SizeType32 batchSlot, + SharedConstPtr const& requestIds, SizeType32 endId, SizeType32 beamWidth, SizeType32 maxSequenceLength, + BufferManager const& manager) +{ TensorPtr const endIdTensorPtr{ITensor::slice(constPointerCast(dJointInput.endIds), batchSlot, 1)}; - runtime::kernels::invokeFill(*endIdTensorPtr, endId, decoderStream); + runtime::kernels::invokeFill(*endIdTensorPtr, endId, manager.getStream()); + // fill outputIds with endIds + TensorPtr const outputIds = ITensor::slice(dJointOutput.ids, batchSlot, 1); + auto outputIdsTileView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, maxSequenceLength})); + runtime::kernels::invokeFill(*outputIdsTileView, endId, manager.getStream()); + + // copy the request ids into outputIds + auto const requestIdsShape = requestIds->getShape(); + auto outputIdsView = ITensor::view(outputIds, requestIdsShape); + manager.copy(*requestIds, *outputIdsView); +} + +void initializeBeamSearch(DecodingInput& dJointInput, DecodingOutput& dJointOutput, SizeType32 batchSlot, + SizeType32 endId, SizeType32 beamWidth, SizeType32 maxSequenceLength, BufferManager const& manager) +{ + TensorPtr const cumLogProbs = ITensor::slice(dJointOutput.cumLogProbs, batchSlot, 1); + runtime::kernels::invokeFill( + *IBuffer::slice(cumLogProbs, 1, beamWidth - 1), DecodingOutput::kNegativeInfinity, manager.getStream()); + + auto parentIds = ITensor::slice(dJointOutput.parentIds, batchSlot, 1); + auto const outputIdsShape = ITensor::makeShape({1, beamWidth, maxSequenceLength}); + parentIds->reshape(outputIdsShape); + manager.setZero(*parentIds); + + auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1); + manager.setZero(*cacheIndirectionInput); + + auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1); + manager.setZero(*cacheIndirectionOutput); + + auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1); + beamHypotheses.init(manager, endId); +} + +void initializeEmbeddingBias(DecodingInput& dJointInput, SizeType32 batchSlot, + std::optional<TensorPtr> const& embeddingBias, nvinfer1::DataType logitsType, + runtime::ModelConfig const& modelConfig, BufferManager const& manager) +{ TensorPtr const embeddingBiasSlice = ITensor::slice(constPointerCast(dJointInput.embeddingBias), batchSlot, 1); - if (request.embeddingBias) + if (embeddingBias.has_value()) { - TLLM_CHECK(request.embeddingBias->getShape().nbDims == 2); - TLLM_CHECK(request.embeddingBias->getShape().d[0] == 1); - TLLM_CHECK_WITH_INFO(request.embeddingBias->getShape().d[1] == modelConfig.getVocabSize(), + auto embeddingBiasTensor = getEmbeddingBias(logitsType, embeddingBias.value()); + + TLLM_CHECK(embeddingBiasTensor->getShape().nbDims == 2); + TLLM_CHECK(embeddingBiasTensor->getShape().d[0] == 1); + TLLM_CHECK_WITH_INFO(embeddingBiasTensor->getShape().d[1] == modelConfig.getVocabSize(), "The embedding bias shape is not as expected. Expected last dimension to be same as vocab size: %d.", modelConfig.getVocabSize()); - manager.copy(*request.embeddingBias, *embeddingBiasSlice); + manager.copy(*embeddingBiasTensor, *embeddingBiasSlice); } else { manager.setZero(*embeddingBiasSlice); } +} - auto setupWords = [](std::vector<runtime::ITensor::SharedPtr>& jointWordsLists, TensorPtr const& requestWordsList, - SharedConstPtr& jointWordsPtrs, SharedConstPtr& jointWordsLens, SizeType32& jointMaxWordsLen, - SizeType32 batchSlot) +void setupWords(std::vector<runtime::ITensor::SharedPtr>& jointWordsLists, + std::optional<TensorPtr> const& requestWordsList, SharedConstPtr& jointWordsPtrs, SharedConstPtr& jointWordsLens, + SizeType32& jointMaxWordsLen, SizeType32 batchSlot, BufferManager const& manager) +{ + if (requestWordsList.has_value()) { - if (requestWordsList) - { - auto const wordsLen = requestWordsList->getShape().d[1]; - BufferRange<int32_t*>(*constPointerCast(jointWordsPtrs))[batchSlot] - = runtime::bufferCast<TokenIdType>(*requestWordsList); - runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = wordsLen; - // FIXME: this is monotonically growing size - jointMaxWordsLen = std::max(static_cast<SizeType32>(wordsLen), jointMaxWordsLen); + // Move to GPU and remove leading bs1 dimension since this is what decoderRequest expects + TensorPtr wordsList = manager.copyFrom(*requestWordsList.value(), MemoryType::kGPU); + wordsList->squeeze(0); - // NOTE: jointWordsList is not used in gptDecoder, but required to keep <name>WordsList's - // memory allocated - jointWordsLists[batchSlot] = requestWordsList; - } - else - { - runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = 0; - } - }; + auto const wordsLen = wordsList->getShape().d[1]; + BufferRange<int32_t*>(*constPointerCast(jointWordsPtrs))[batchSlot] + = runtime::bufferCast<TokenIdType>(*wordsList); + runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = wordsLen; + // FIXME: this is monotonically growing size + jointMaxWordsLen = std::max(static_cast<SizeType32>(wordsLen), jointMaxWordsLen); - setupWords(dJointInput.stopWordsLists, request.stopWordsList, dJointInput.stopWordsPtrs, dJointInput.stopWordsLens, - dJointInput.maxStopWordsLen, batchSlot); - - setupWords(dJointInput.badWordsLists, request.badWordsList, dJointInput.badWordsPtrs, dJointInput.badWordsLens, - dJointInput.maxBadWordsLen, batchSlot); - - TensorPtr const sequenceLimitLength{ - ITensor::slice(constPointerCast(dJointInput.sequenceLimitLength), batchSlot, 1)}; - runtime::kernels::invokeFill(*sequenceLimitLength, inputLength + maxNewTokens, decoderStream); - - TensorPtr const inputLengths{ITensor::slice(constPointerCast(dJointInput.lengths), batchSlot, 1)}; - runtime::kernels::invokeFill(*inputLengths, inputLength, decoderStream); - - // output - auto& dJointOutput = decoderState.getJointDecodingOutput(); - auto const outputIdsShape = ITensor::makeShape({1, beamWidth, maxSequenceLength}); - - auto finishedSum = ITensor::slice(dJointOutput.finishedSum, batchSlot, 1); - manager.setZero(*finishedSum); - - for (SizeType32 ti = 0; ti < decoderState.getMaxDecodingEngineTokens(); ++ti) - { - TensorPtr const newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, 1); - newTokensStepView->squeeze(0); - auto newTokensVec = ITensor::slice(newTokensStepView, batchSlot, 1); - manager.setZero(*newTokensVec); + // NOTE: jointWordsList is not used in gptDecoder, but required to keep <name>WordsList's + // memory allocated + jointWordsLists[batchSlot] = wordsList; } + else + { + runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = 0; + } +}; - TensorPtr const finishedStepsSlice = ITensor::slice(decoderState.getFinishReasons(), batchSlot, 1); - manager.setZero(*finishedStepsSlice); +void initializeLogProbs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SamplingConfig const& samplingConfig, + BufferManager const& manager) +{ + auto const beamWidth = samplingConfig.beamWidth; // cumLogProb is mandatory for beamWidth > 1 if ((samplingConfig.cumLogProbs.has_value() && samplingConfig.cumLogProbs->at(0)) || beamWidth > 1) @@ -287,49 +296,32 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder auto logProbs = ITensor::slice(dJointOutput.logProbs, batchSlot, 1); manager.setZero(*logProbs); } +} - if (beamWidth > 1) +void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeType32 maxDecodingEngineTokens, + BufferManager const& manager) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto finishedSum = ITensor::slice(dJointOutput.finishedSum, batchSlot, 1); + manager.setZero(*finishedSum); + + for (SizeType32 ti = 0; ti < maxDecodingEngineTokens; ++ti) { - TensorPtr const cumLogProbs = ITensor::slice(dJointOutput.cumLogProbs, batchSlot, 1); - runtime::kernels::invokeFill( - *IBuffer::slice(cumLogProbs, 1, beamWidth - 1), DecodingOutput::kNegativeInfinity, decoderStream); - - auto parentIds = ITensor::slice(dJointOutput.parentIds, batchSlot, 1); - parentIds->reshape(outputIdsShape); - manager.setZero(*parentIds); - - auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1); - manager.setZero(*cacheIndirectionInput); - - auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1); - manager.setZero(*cacheIndirectionOutput); - - auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1); - beamHypotheses.init(manager, endId); + TensorPtr const newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, 1); + newTokensStepView->squeeze(0); + auto newTokensVec = ITensor::slice(newTokensStepView, batchSlot, 1); + manager.setZero(*newTokensVec); } - // Speculative execution - if (numDecodingEngineTokens > 1 || decoderState.getSpeculativeDecodingMode().isDraftTokensExternal()) - { - TLLM_CHECK(beamWidth == 1); - newRequestSpeculativeDecoding(batchSlot, request, samplingConfig, modelConfig, - decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, decoderStream, - decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens()); - } - - // fill outputIds with endIds - TensorPtr const outputIds = ITensor::slice(dJointOutput.ids, batchSlot, 1); - auto outputIdsTileView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, maxSequenceLength})); - runtime::kernels::invokeFill(*outputIdsTileView, endId, decoderStream); - - // copy the request ids into outputIds - auto const requestIdsShape = requestIds->getShape(); - auto outputIdsView = ITensor::view(outputIds, requestIdsShape); - manager.copy(*requestIds, *outputIdsView); + TensorPtr const finishedStepsSlice = ITensor::slice(dJointOutput.finishReasons, batchSlot, 1); + manager.setZero(*finishedStepsSlice); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +} // namespace + void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, @@ -557,11 +549,12 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec std::tuple<std::vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>> CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds, executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState, - BufferManager const& bufferManager, nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, - runtime::WorldConfig const& worldConfig, runtime::CudaStream const& runtimeStream, - runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, + nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, + runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const { + auto const decoderBufferManager = BufferManager{std::make_shared<CudaStream>(decoderStream.get())}; + unsigned decoderInputSize{0}; for (auto const& llmReq : finishedContextRequests) { @@ -586,26 +579,38 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon SizeType32 inputOffset{0}; for (auto const& llmReq : finishedContextRequests) { - auto const promptLen = llmReq->getPromptLen(); - auto const& reqTokens = llmReq->getTokens(0); - TLLM_CHECK(reqTokens.size() == static_cast<decltype(reqTokens.size())>(promptLen)); - TensorPtr inputView = ITensor::slice(inputIds, inputOffset, promptLen); - bufferManager.copy(reqTokens.data(), *inputView); - - auto decoderRequest = decoder_batch::Request{inputView, promptLen, llmReq->mMaxNewTokens, llmReq->mEndId}; - llmReq->mSamplingConfig.normalizeLogProbs = mIsNormalizeLogProbs; + + TLLM_CHECK(llmReq->mSeqSlot.has_value()); + auto const batchSlot = llmReq->mSeqSlot.value(); + auto const batchSize = decoderState.getMaxNumSequences(); + TLLM_CHECK(0 <= batchSlot && batchSlot < batchSize); + + auto const& samplingConfig = llmReq->mSamplingConfig; + + auto const beamWidth = samplingConfig.beamWidth; + auto const maxBeamWidth = decoderState.getMaxBeamWidth(); + TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth, + tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.", + beamWidth, maxBeamWidth)); + decoderState.setBeamWidth(batchSlot, beamWidth); + + auto const promptLen = llmReq->getPromptLen(); + + auto decoderRequest = decoder_batch::Request{promptLen}; + if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal()) { if (llmReq->hasDraftTokens()) { auto const& draftTokens = llmReq->getDraftTokens(); - decoderRequest.draftTokens = bufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL); + // Copy to pinned host memory (don't care about stream of bufferManager) + decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL); auto const& draftLogits = llmReq->getDraftLogits(); if (draftLogits.has_value()) { decoderRequest.draftLogits - = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), bufferManager); + = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager); } decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1; } @@ -618,48 +623,77 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon { decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens(); } - if (modelConfig.getSpeculativeDecodingMode().isMedusa()) - { - TLLM_CHECK(medusaBuffers); - llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs}; - // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? - // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. - decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1); - decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1); - } - else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) - { - lookaheadPrompt.emplace_back(ITensor::slice(decoderRequest.ids, 0, decoderRequest.inputLen)); - auto const& lookaheadRuntimeConfig - = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value()); - lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig); - } - else if (modelConfig.getSpeculativeDecodingMode().isEagle()) + auto& dJointInput = decoderState.getJointDecodingInput(); + + auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep; + initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens, + maxSequenceLength, decoderBufferManager); + decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens); + + initializeEmbeddingBias( + dJointInput, batchSlot, llmReq->getEmbeddingBias(), logitsType, modelConfig, decoderBufferManager); + + setupWords(dJointInput.badWordsLists, llmReq->getBadWordsList(), dJointInput.badWordsPtrs, + dJointInput.badWordsLens, dJointInput.maxBadWordsLen, batchSlot, decoderBufferManager); + + setupWords(dJointInput.stopWordsLists, llmReq->getStopWordsList(), dJointInput.stopWordsPtrs, + dJointInput.stopWordsLens, dJointInput.maxStopWordsLen, batchSlot, decoderBufferManager); + + auto& dJointOutput = decoderState.getJointDecodingOutput(); + + initializeOutputs(dJointOutput, batchSlot, decoderState.getMaxDecodingEngineTokens(), decoderBufferManager); + + initializeLogProbs(dJointOutput, batchSlot, samplingConfig, decoderBufferManager); + + auto const& reqTokens = llmReq->getTokens(0); + TLLM_CHECK(reqTokens.size() == static_cast<decltype(reqTokens.size())>(promptLen)); + TensorPtr requestIds = ITensor::slice(inputIds, inputOffset, promptLen); + // Copy to pinned host memory (don't care about stream of bufferManager) + decoderBufferManager.copy(reqTokens.data(), *requestIds); + auto const endId = llmReq->mEndId.value_or(-1); + + initializeRequestIds(dJointInput, dJointOutput, batchSlot, requestIds, endId, beamWidth, maxSequenceLength, + decoderBufferManager); + + if (beamWidth > 1) { - decoderRequest.eagleConfig - = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig(); - } - if (llmReq->getEmbeddingBias().has_value()) - { - decoderRequest.embeddingBias = getEmbeddingBias(logitsType, llmReq->getEmbeddingBias().value()); - } - if (llmReq->getBadWordsList().has_value()) - { - // Move to GPU and remove leading bs1 dimension since this is what decoderRequest expects - decoderRequest.badWordsList = bufferManager.copyFrom(*llmReq->getBadWordsList().value(), MemoryType::kGPU); - decoderRequest.badWordsList->squeeze(0); - } - if (llmReq->getStopWordsList().has_value()) - { - decoderRequest.stopWordsList - = bufferManager.copyFrom(*llmReq->getStopWordsList().value(), MemoryType::kGPU); - decoderRequest.stopWordsList->squeeze(0); + initializeBeamSearch( + dJointInput, dJointOutput, batchSlot, endId, beamWidth, maxSequenceLength, decoderBufferManager); } - TLLM_CHECK(llmReq->mSeqSlot.has_value()); - newRequest(llmReq->mSeqSlot.value(), decoderRequest, llmReq->mSamplingConfig, modelConfig, decoderState, - runtimeStream, decoderStream, maxSequenceLength); + // Speculative execution + if (!decoderState.getSpeculativeDecodingMode().isNone()) + { + TLLM_CHECK(beamWidth == 1); + + if (modelConfig.getSpeculativeDecodingMode().isMedusa()) + { + TLLM_CHECK(medusaBuffers); + llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs}; + // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? + // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. + decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1); + decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1); + } + else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) + { + lookaheadPrompt.emplace_back(requestIds); + + auto const& lookaheadRuntimeConfig + = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value()); + lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig); + } + else if (modelConfig.getSpeculativeDecodingMode().isEagle()) + { + decoderRequest.eagleConfig + = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig(); + } + + newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig, + decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, + decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens()); + } decoderRequests.push_back(decoderRequest); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index a4617c0d53..522ec80f84 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -91,6 +91,43 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) return totalSize; } +void TransferSession::appendMeasure(double delay, double duration, size_t size) +{ + if (!mRecordMeasure) + { + return; + } + auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps + mMeasures.emplace_back(Measure{delay, duration, bandwidth}); +} + +void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const +{ + if (mMeasures.empty()) + { + return; + } + // write header if not exist + if (outFile.tellp() == 0) + { + outFile << "RequestID"; + for (size_t i = 0; i < mMeasures.size(); i++) + { + outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; + } + outFile << '\n'; + } + // write measures + TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value()); + auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId(); + outFile << reqId; + for (auto const& measure : mMeasures) + { + outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; + } + outFile << '\n' << std::flush; +} + class DataResponder::Impl { public: diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index 91215ff66c..ef66cd1382 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -97,15 +97,23 @@ private: class TransferSession { public: + struct Measure + { + double delay; // from last token (ctx) or arrival time (gen), in ms + double duration; // in ms + double bandwidth; // in Gbps + }; + TransferSession(std::vector<Connection const*> connections, DataContext dataContext, executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, - runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr) + runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) : mConnections(std::move(connections)) , mDataContext(dataContext) , mSelfState(&selfState) , mOtherState(std::move(otherState)) , mBufferManager(&bufferManager) , mRequest(llmRequest) + , mRecordMeasure(recordMeasure) { TLLM_CHECK(!mConnections.empty()); } @@ -163,6 +171,11 @@ public: mRequest = &llmRequest; } + void appendMeasure(double delay, double duration, size_t size); + + // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file + void exportMeasure(std::ofstream& outFile, bool isContext) const; + private: std::vector<Connection const*> mConnections; DataContext mDataContext; @@ -170,6 +183,8 @@ private: executor::DataTransceiverState mOtherState; runtime::BufferManager const* mBufferManager; LlmRequest const* mRequest; + bool mRecordMeasure; + std::vector<Measure> mMeasures; }; // Operators required for data transmission in specific communication protocols. @@ -266,79 +281,4 @@ private: std::unique_ptr<Impl> mImpl; }; -class KvCacheMeasureHelper -{ -public: - struct Measure - { - double delay; // from last token (ctx) or arrival time (gen), in ms - double duration; // in ms - double bandwidth; // in Gbps - }; - - KvCacheMeasureHelper(std::string output_path) - : mOutputPath(std::move(output_path)) - { - } - - void markAsSender(bool isSender) - { - mIsSender = isSender; - } - - void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size) - { - auto bandwidth = size * 8 / (duration / 1000) / 1e9; - if (mOutputPath.empty()) - { - return; - } - - std::lock_guard<std::mutex> lock(mMutex); - mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth}); - } - - ~KvCacheMeasureHelper() - { - if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty()) - { - TLLM_CHECK(mIsSender.has_value()); - auto rank = mpi::MpiComm::world().getRank(); - std::string outFilePath - = mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? "send" : "recv") + ".csv"; - std::ofstream outFile(outFilePath); - - TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath); - - size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size(); - - outFile << "RequestID"; - for (size_t i = 0; i < numTransferMeasure; i++) - { - outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; - } - outFile << '\n'; - - for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure) - { - outFile << requestID; - - for (auto const& measure : measures) - { - outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; - } - outFile << '\n'; - } - - outFile.close(); - } - } - -private: - std::map<LlmRequest::RequestIdType, std::vector<Measure>> mRequestKVCacheTranfserMeasure; - std::string mOutputPath; - std::mutex mMutex; - std::optional<bool> mIsSender; -}; - } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp index 9a72bf2d00..1a5c7fab4d 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @@ -21,6 +21,8 @@ #include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include <filesystem> + namespace tensorrt_llm::batch_manager { @@ -30,6 +32,21 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId) return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF); } +namespace fs = std::filesystem; + +static fs::path getTransferOutputPath(char const* tag) +{ + auto outputPath = common::getEnvKVCacheTransferOutputPath(); + if (!outputPath.empty()) + { + auto rank = mpi::MpiComm::world().getRank(); + auto path = fs::path(outputPath); + fs::create_directories(path); + return path / ("rank_" + std::to_string(rank) + "_" + tag + ".csv"); + } + return {}; +} + DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter) : mManager{manager} @@ -39,7 +56,6 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, { TLLM_CHECK(mManager); TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); - mFormatter->markAsSender(true); } [[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo() @@ -86,7 +102,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, if (it == mRequestToSession.end()) { auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr), - DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager); + DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr, + !common::getEnvKVCacheTransferOutputPath().empty()); it = mRequestToSession.emplace(requestId, std::move(session)).first; } it->second.setConnection(peerIdx, connection); @@ -125,6 +142,17 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId) auto it = mRequestToSession.find(requestId); TLLM_CHECK(it != mRequestToSession.end()); std::unique_lock<std::mutex> lk(mMtxForMap); + if (!common::getEnvKVCacheTransferOutputPath().empty()) + { + if (!mMeasuresFile.is_open()) + { + auto outputPath = getTransferOutputPath("send"); + mMeasuresFile.open(outputPath); + TLLM_CHECK_WITH_INFO( + mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); + } + it->second.exportMeasure(mMeasuresFile, true); + } mRequestToSession.erase(it); } @@ -137,7 +165,6 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage TLLM_CHECK(mManager); TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); TLLM_CHECK(mFormatter); - mFormatter->markAsSender(false); } TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) @@ -203,12 +230,24 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) } auto const& resource = getReceiveCacheResource(llmRequest); return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, - contextState, resource->mBufferManager, &llmRequest); + contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty()); } void DataReceiverImpl::receiveSync(TransferSession& session) { mFormatter->unformat(session); + if (!common::getEnvKVCacheTransferOutputPath().empty()) + { + std::unique_lock<std::mutex> lock(mMeasuresFileMutex); + if (!mMeasuresFile.is_open()) + { + auto outputPath = getTransferOutputPath("recv"); + mMeasuresFile.open(outputPath); + TLLM_CHECK_WITH_INFO( + mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); + } + session.exportMeasure(mMeasuresFile, false); + } } void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h index fa8d272832..2f277f14ff 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @@ -23,6 +23,8 @@ #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h" +#include <fstream> + namespace tensorrt_llm::batch_manager { struct TransceiverTag @@ -67,6 +69,7 @@ private: std::unique_ptr<BaseCacheFormatter> mFormatter; std::mutex mMtxForMap; runtime::BufferManager mBufferManager; + std::ofstream mMeasuresFile; }; class DataReceiverImpl : public DataReceiver, public TransceiverTag @@ -103,6 +106,8 @@ private: std::unique_ptr<BaseCacheFormatter> mFormatter; std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources; std::mutex mProcessIoResouceMutex; + std::ofstream mMeasuresFile; + std::mutex mMeasuresFileMutex; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp index 040dcd147e..ea5f098107 100644 --- a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp +++ b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp @@ -88,8 +88,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests) continue; } auto const seqSlot = llmReq->mSeqSlot.value(); - if (llmReq->isContextInitState() - && llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen()) + if (llmReq->isContextInitState() && llmReq->isFirstContextChunk()) { // The request is in the first context forward step (considering kv cache reuse). auto const& guideType = guidedDecodingParams->getGuideType(); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp index ff2a2f6b78..ac37278d45 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp @@ -18,20 +18,51 @@ #include "tensorrt_llm/batch_manager/kvCacheEventManager.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/serialization.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" namespace tle = tensorrt_llm::executor; namespace tensorrt_llm::batch_manager::kv_cache_manager { -KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries) +KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank, + std::optional<SizeType32> attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs) : mRun{true} , mMaxSize{maxKVEventEntries} , mEventId{0} + , mAttentionDpRank{attentionDpRank} + , mAttentionDpSize{attentionDpSize} + , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) { TLLM_CHECK(mMaxSize > 0); - // mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this)); + if (mAttentionDpRank) + { + TLLM_CHECK_WITH_INFO( + mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set"); + TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(), + "Attention DP rank must be less than attention DP size"); + if (mAttentionDpRank.value() == 0) + { + // Rank 0 will gather events from all other ranks + // Need to increase size + mMaxSize *= mAttentionDpSize.value(); + } + // Create a communicator to be used for event exchange + mMpiComm = std::make_unique<tensorrt_llm::mpi::MpiComm>(COMM_SESSION.split(0, mAttentionDpRank.value())); + } + else + { + TLLM_CHECK_WITH_INFO( + !mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set"); + } mWorkerThread = std::thread([this]() { this->worker(); }); +#if ENABLE_MULTI_DEVICE + if (mAttentionDpRank) + { + mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); }); + } +#endif }; KVCacheEventManager::~KVCacheEventManager() @@ -40,12 +71,18 @@ KVCacheEventManager::~KVCacheEventManager() mPendingEmptyCV.notify_all(); mEmptyCV.notify_all(); mWorkerThread.join(); +#if ENABLE_MULTI_DEVICE + if (mAttentionDpRank) + { + mExchangeAttentionDpThread.join(); + } +#endif } void KVCacheEventManager::enqueueCreatedEvent( std::vector<SizeType32> const& numBlocksPerCacheLevel, SizeType32 windowSize) { - enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize}); + enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks, SizeType32 windowSize) @@ -68,7 +105,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority()); } - enqueueEvent({mEventId++, data, windowSize}); + enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize) @@ -81,13 +118,13 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 } else { - enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize}); + enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank}); } } void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize) { - enqueueEvent({mEventId++, data, windowSize}); + enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueEvent(tle::KVCacheEvent&& event) @@ -120,8 +157,76 @@ void KVCacheEventManager::flush() mPendingEmptyCV.notify_one(); } +void KVCacheEventManager::exchangeAttentionDpThread() +{ +#if ENABLE_MULTI_DEVICE + while (true) + { + TLLM_CHECK(mAttentionDpRank); + + // Check if any of the ranks have been shutdown + int32_t numFinished = 0; + int32_t finished = mRun ? 0 : 1; + mMpiComm->allreduce(&finished, &numFinished, 1, mpi::MpiType::kINT32, mpi::MpiOp::SUM); + if (numFinished > 0) + { + TLLM_LOG_INFO("One of the rank has been shut down, exiting"); + break; + } + + // If we are not rank 0, send events to rank 0 + if (mAttentionDpRank.value() != 0) + { + std::vector<char> serializedEvents; + uint64_t numEvents = 0; + { + std::lock_guard<std::mutex> lck(mEventsMutex); + serializedEvents = executor::Serialization::serialize(mEvents); + numEvents = mEvents.size(); + mEvents.clear(); + } + uint64_t vecSize = numEvents > 0 ? serializedEvents.size() : 0; + mMpiComm->send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize); + if (vecSize > 0) + { + mMpiComm->send(serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0, + mpi::MpiTag::kKvCacheEvent); + } + } + else + { + TLLM_CHECK(mAttentionDpSize.has_value()); + // Loop until have received events from all ranks + for (int rank = 1; rank < mAttentionDpSize.value(); ++rank) + { + uint64_t vecSize{0}; + mMpiComm->recv(&vecSize, 1, mpi::MpiType::kUINT64, rank, mpi::MpiTag::kKvCacheEventSize); + if (vecSize > 0) + { + std::vector<char> serializedEvents(vecSize); + mMpiComm->recv( + serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, rank, mpi::MpiTag::kKvCacheEvent); + + // Deserialize the events and add them to the local queue + auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents); + { + std::lock_guard<std::mutex> lck(mEventsMutex); + mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end()); + mEmptyCV.notify_one(); + } + } + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs)); + } +#else + TLLM_THROW("Multi device support is disabled."); +#endif +} + void KVCacheEventManager::worker() { + while (true) { std::deque<tle::KVCacheEvent> events; @@ -151,6 +256,8 @@ void KVCacheEventManager::worker() // If there's still too many events, take from the front of the events queue. mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end()); + + // Notify the empty condition variable to wake up any waiting threads mEmptyCV.notify_one(); } } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 4202ba348a..d5fa982a37 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -504,8 +504,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority, - std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())} , mTokensPerBlock{tokensPerBlock} , mEventManager{std::move(eventManager)} @@ -530,7 +529,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks... mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer, sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream, - onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enableHashKey, enablePartialReuse, + onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, copyOnPartialReuse); } @@ -573,8 +572,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority, - std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mDataType{dtype} , mWindowSize{windowSize} , mNumPrimaryBlocks{blocksInPrimaryPool} @@ -596,7 +594,6 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mLogPrefix{tensorrt_llm::common::fmtstr("BlockManager[windowSize=%u]", mWindowSize)} , mReusedTokens{0.0} , mTotalInputTokens{0.0} - , mEnableHashKey{enableHashKey} , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} { @@ -920,50 +917,6 @@ void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); } -void WindowBlockManager::addBlockToHashMap(BlockPtr const& block) -{ - if (!mEnableHashKey) - { - return; - } - auto range = mContextBlocksByHash.equal_range(block->getHash()); - for (auto it = range.first; it != range.second; ++it) - { - if (it->second == block) - { - // TODO: change to assert when reused block is added only once - TLLM_LOG_TRACE( - "Block %d by %zx exists", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - return; - } - } - TLLM_LOG_TRACE( - "Add block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - mContextBlocksByHash.emplace(block->getHash(), std::move(block)); -} - -void WindowBlockManager::removeBlockFromHashMap(BlockPtr const& block) -{ - if (mContextBlocksByHash.empty() || block->getBlockKey().uniqueTokens.empty()) - { - // Hash key not enabled / Empty block - return; - } - auto range = mContextBlocksByHash.equal_range(block->getHash()); - TLLM_LOG_TRACE( - "Remove block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - for (auto it = range.first; it != range.second; ++it) - { - if (it->second == block) - { - mContextBlocksByHash.erase(it); - return; - } - } - // TODO: should be unreachable - TLLM_LOG_DEBUG("Trying to remove block %d by %zx that is not in hash map", block->getBlockId(), block->getHash()); -} - void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize) { mWindowBlockManagers.at(windowSize).onboardBlock(offloadBlock); @@ -1104,7 +1057,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const& matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(), matchingBlockId); - addBlockToHashMap(matchingBlock); } searchRoot = nullptr; // no matching needed for following blocks } @@ -1114,7 +1066,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const& mEvictionPolicy->claimBlock( matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId); - addBlockToHashMap(matchingBlock); searchRoot = matchingBlock; } onboardBlock(matchingBlock); @@ -1145,7 +1096,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const& ++blockItr; } freeBlock->setHash(); - addBlockToHashMap(freeBlock); ++mMissedBlocks; } } @@ -1169,7 +1119,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const& ++blockItr; } freeBlock->setHash(); - addBlockToHashMap(freeBlock); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. Allocated non-shared block %d for bi %d", mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi); } @@ -1369,9 +1318,7 @@ void WindowBlockManager::storeBlocks( if (oldHash != newHash) { TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash); - removeBlockFromHashMap(block); block->setHash(newHash); - addBlockToHashMap(block); } searchRoot = block; } @@ -1408,7 +1355,6 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); - removeBlockFromHashMap(block); } } @@ -1473,7 +1419,6 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence) if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block, true); - removeBlockFromHashMap(block); } // Remove block from allocated blocks allocatedBlocks.pop_back(); @@ -1616,7 +1561,6 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence) if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); - removeBlockFromHashMap(block); } } // Remove stored block ids in sequence @@ -1654,8 +1598,7 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size : KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength, - enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, false, enablePartialReuse, - copyOnPartialReuse) + enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse) { } @@ -1682,8 +1625,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority, - std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mMaxBeamWidth(maxBeamWidth) , mDataType(dtype) , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end())) @@ -1693,10 +1635,9 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager), - enableHashKey, enablePartialReuse, copyOnPartialReuse) + enablePartialReuse, copyOnPartialReuse) // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse} - , mEnableHashKey{enableHashKey} { TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow) != maxAttentionWindowVec.end()); @@ -1716,12 +1657,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority, - std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, - std::move(eventManager), enableHashKey, enablePartialReuse, copyOnPartialReuse) + std::move(eventManager), enablePartialReuse, copyOnPartialReuse) { } @@ -2085,30 +2025,6 @@ void KVCacheManager::addSequence( llmRequest->mRequestId); } mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize); - if (mEnableHashKey && llmRequest.has_value() && beamWidth == 1) - { - constexpr SizeType32 beamIdx = 0; - auto const& blockIds = sequence.getCacheBlockIds(windowSize).at(beamIdx); - auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); - auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>( - uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), true); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); - auto tokensPerBlock = static_cast<size_t>(getTokensPerBlock()); - for (size_t i = 0; i < blockIds.size(); i++) - { - auto const& block = mBlockManager.getBlockById(blockIds[i], windowSize); - if (i < blockKeys.size()) - { - block->setBlockKey(blockKeys[i], blockKeys[i].uniqueTokens.size() == tokensPerBlock); - } - else - { - block->setBlockKey({}, false); - } - block->setHash(); - mBlockManager.addBlockToHashMap(block, windowSize); - } - } } cacheBlockOffsets(sequence, windowSize); } @@ -2127,10 +2043,13 @@ void KVCacheManager::addSequence( void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) { auto const requestId = llmRequest.mRequestId; - auto& sequence = getSequence(requestId); - if (mEnableBlockReuse && !sequence.isCyclic() && !llmRequest.isDummyRequest()) + if (mSequences.find(requestId) != mSequences.end()) { - mBlockManager.storeContextBlocks(sequence, llmRequest); + auto& sequence = getSequence(requestId); + if (mEnableBlockReuse && !sequence.isCyclic() && !llmRequest.isDummyRequest()) + { + mBlockManager.storeContextBlocks(sequence, llmRequest); + } } } diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index a9a4aec5df..dcebc9c3ac 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -365,4 +365,10 @@ void LlmRequest::moveLoraWeightsToGpu(runtime::BufferManager const& manager) mLoraWeights = gpuLoraWeights; } +void LlmRequest::removeLoraTensors() +{ + mLoraWeights.reset(); + mLoraConfig.reset(); +} + } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 810edd6f45..22756f2552 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -236,7 +236,7 @@ void MLACacheFormatter::format(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (connections.size() > 1) @@ -433,7 +433,7 @@ void MLACacheFormatter::unformat(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (pickUpConnections.size() > 1) @@ -583,6 +583,28 @@ void MLACacheFormatter::unformat(TransferSession& session) return false; } + int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(); + int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism; + int destPPSize = destConfig.getParallelConfig().mPipelineParallelism; + int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); + + if (selfPPSize == destPPSize) + { + return true; + } + if (selfNumLayers % selfPPSize != 0) + { + TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d", + selfNumLayers, selfPPSize); + return false; + } + if (destNumLayers % destPPSize != 0) + { + TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d ", + destNumLayers, destPPSize); + return false; + } + return true; } } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp index f513f2a3a1..cc62bd3eb0 100644 --- a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp @@ -591,10 +591,9 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRe TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); if (llmRequest->getLoraTaskId().has_value()) { - auto taskId = llmRequest->getLoraTaskId().value(); try { - return mHostLoraCache->determineNumPages(taskId); + return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value()); } catch (std::runtime_error& e) { @@ -602,16 +601,6 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRe { return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value()); } - if (!llmRequest->getLoraWeights().has_value()) - { - auto const reqId = llmRequest->mRequestId; - std::string errMsg - = "Request ID " + std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task " - + std::to_string(taskId) + " that's not found in LoRA CPU cache." - " Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization," - " so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported."; - throw PeftTaskNotCachedException(errMsg); - } throw; } } diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 4a5ddb8928..08cb4d407c 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -693,7 +693,7 @@ std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::c kvCacheConfig.getEventBufferMaxSize() > 0 ? std::make_unique<kv_cache_manager::KVCacheEventManager>(kvCacheConfig.getEventBufferMaxSize()) : nullptr, - false, kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse()); + kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse()); reshapeKvTensors(kvCacheManager->getOffsetTableDimensions()); @@ -1866,9 +1866,9 @@ void TrtGptModelInflightBatching::setupDecoderStep( auto const logitsType = mRuntime->getEngine().getTensorDataType("logits"); auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] - = (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, - mRuntime->getBufferManager(), logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(), - *mDecoder->getDecoderStream(), getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers); + = (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType, + inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), getMaxSequenceLen(), + mOperatingBeamWidth, buffers.mMedusaBuffers); auto const localBatchSize = batchSlots->getSize(); if (localBatchSize > 0) diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index cb252a44d2..7d1118dd6a 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -55,6 +55,7 @@ struct FusedQKVMaskedAttentionDispatchParams T const* qkv_bias; T const* relative_attention_bias; bool const* attention_mask; + float const* attention_sinks; float const* logn_scaling_ptr; int const* cache_indir; void* context_buf; @@ -71,6 +72,7 @@ struct FusedQKVMaskedAttentionDispatchParams RotaryScalingType rotary_embedding_scale_type; float rotary_embedding_scale; float const* rotary_embedding_inv_freq_cache; + float2 const* rotary_embedding_cos_sin_cache; float rotary_embedding_short_m_scale; float rotary_embedding_long_m_scale; int rotary_embedding_max_positions; @@ -225,6 +227,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.output = generationsParams.context_buf; xqaParams.qkv = generationsParams.attention_input; xqaParams.cache_indir = generationsParams.cache_indir; + xqaParams.attention_sinks = generationsParams.attention_sinks; xqaParams.kv_scale_orig_quant = generationsParams.kv_scale_orig_quant; xqaParams.kv_scale_quant_orig = generationsParams.kv_scale_quant_orig; xqaParams.host_past_key_value_lengths = generationsParams.host_past_key_value_lengths; @@ -275,7 +278,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr; xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens; xqaParams.is_fp8_output = mFP8ContextFMHA; - xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr); + xqaParams.fp8_out_scale + = ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr); // Parameters required for FP4 output. xqaParams.output_sf = generationsParams.context_buf_sf; xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale; @@ -596,6 +600,7 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS params.rotary_embedding_scale_type = input_params.rotary_embedding_scale_type; params.rotary_embedding_scale = input_params.rotary_embedding_scale; params.rotary_embedding_inv_freq_cache = input_params.rotary_embedding_inv_freq_cache; + params.rotary_embedding_cos_sin_cache = input_params.rotary_embedding_cos_sin_cache; params.rotary_embedding_short_m_scale = input_params.rotary_embedding_short_m_scale; params.rotary_embedding_long_m_scale = input_params.rotary_embedding_long_m_scale; params.rotary_embedding_max_positions = input_params.rotary_embedding_max_positions; @@ -620,6 +625,9 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS params.attention_mask = input_params.attention_mask; params.attention_mask_stride = input_params.attention_mask_stride; + // Attention sinks. + params.attention_sinks = input_params.attention_sinks; + // The slope of linear position bias per head, e.g., ALiBi. if (input_params.linear_bias_slopes != nullptr) { @@ -729,10 +737,29 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo; size_t const qk_buf_float_size = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length; - size_t const fp8_qkv_buffer_size - = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() + int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_v_per_head = (mMLAParams.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head; + int const total_k_dim_all_heads + = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + + int const num_total_qkv_elements + = max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + + size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; + if (mFP8ContextMLA) + { + fp8_qkv_buffer_size + = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0; + } + size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; // Each token holds (batch_idx, token_idx_in_seq) int2. @@ -1342,10 +1369,26 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea size_t const qk_buf_float_size = mEnableContextFMHA ? 0 : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length; - size_t const fp8_qkv_buffer_size - = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() + int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_v_per_head = (mMLAParams.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head; + int const total_k_dim_all_heads + = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + int const num_total_qkv_elements + = params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; + if (mFP8ContextMLA) + { + fp8_qkv_buffer_size + = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0; + } size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length; size_t const encoder_padding_offset_size @@ -1353,8 +1396,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea // Each token holds (batch_idx, token_idx_in_seq) int2. size_t const tokens_info_size = sizeof(int2) * params.num_tokens; size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0; - size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0; - size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0; + size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0; + size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0; // cp workspace size upper bound size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1); @@ -1601,6 +1644,18 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea params.mla_param->cache_type = cache_type; params.mla_param->cu_q_seqlens = cu_q_seqlens; params.mla_param->quant_scale_kv = params.kv_scale_orig_quant; + // Set BMM scales for FP8 context computation + params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr; + params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr; + params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr; + // Set additional scales for context phase + params.mla_param->quant_scale_o = params.attention_output_orig_quant; + params.mla_param->quant_scale_q = params.kv_scale_orig_quant; + params.mla_param->quant_scale_kv = params.kv_scale_orig_quant; + params.mla_param->dequant_scale_q = params.kv_scale_quant_orig; + params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig; + params.mla_param->host_bmm1_scale + = 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim))); if (mPagedContextFMHA && mPagedKVCache) { TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr, @@ -1679,8 +1734,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea // TODO: set it correctly for contiguous kv buffer (cross-attention). fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens; // Device buffer pointers. - fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer) - : reinterpret_cast<void const*>(attention_input); + fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast<void const*>(fp8_qkv_buffer) + : reinterpret_cast<void const*>(attention_input); fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_); // TODO: add contiguous kv buffer (cross-attention). fmhaParams.kvPtr = nullptr; @@ -1691,6 +1746,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea fmhaParams.outputPtr = mCpSize > 1 ? gatherOutBuffer : params.context_buf; // only use [totalLength, h / cpSize, Dh] fmhaParams.outputSfPtr = params.context_buf_sf; + fmhaParams.attentionSinksPtr = params.attention_sinks; fmhaParams.packedMaskPtr = params.attention_packed_mask; if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>) { @@ -2220,6 +2276,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride; dispatch_params.attention_mask = params.attention_mask; dispatch_params.attention_mask_stride = params.attention_mask_stride; + dispatch_params.attention_sinks = params.attention_sinks; dispatch_params.max_distance = max_distance; dispatch_params.cache_indir = params.cache_indir; dispatch_params.context_buf = mCpSize > 1 ? mhaOutput : params.context_buf; // @@ -2267,6 +2324,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType; dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale; dispatch_params.rotary_embedding_inv_freq_cache = params.rotary_inv_freq; + dispatch_params.rotary_embedding_cos_sin_cache = params.rotary_cos_sin; dispatch_params.rotary_embedding_short_m_scale = mRotaryEmbeddingShortMscale; dispatch_params.rotary_embedding_long_m_scale = mRotaryEmbeddingLongMscale; dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions; @@ -2477,7 +2535,7 @@ int AttentionOp::initialize() noexcept } // FP8 FMHA should be used with fp8 workflow together. - if (mFP8ContextFMHA) + if (mFP8ContextFMHA || mFP8ContextMLA) { data_type = DATA_TYPE_E4M3; } @@ -2510,6 +2568,11 @@ int AttentionOp::initialize() noexcept fmhaParams.dataTypeOut = DATA_TYPE_BF16; fmhaParams.dataTypeKv = DATA_TYPE_BF16; } + if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache()) + { + fmhaParams.dataTypeKv = DATA_TYPE_E4M3; + fmhaParams.dataTypeOut = DATA_TYPE_BF16; + } // TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to // bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime. fmhaParams.forceFp32Acc = false; @@ -2563,7 +2626,7 @@ int AttentionOp::initialize() noexcept // Deepseek-V2 Generation needs a differ fmha with different argumments if (mIsMLAEnabled) { - mEnableXQA = (mSM == kSM_120); + mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA; if (mUseTllmGen) { Data_type qDataType = DATA_TYPE_FP32; @@ -2826,6 +2889,7 @@ std::string AttentionOp::toString() const ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl; ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl; ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl; + ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl; ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl; ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl; ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl; diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index fb71c06d57..25d95dfea2 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -65,6 +65,8 @@ public: T const* qkv_bias = nullptr; // Attention mask input, which has shape of [batch_size, attention_mask_stride]. bool const* attention_mask = nullptr; + // Attention sinks with shape of [num_heads_q] float. + float const* attention_sinks = nullptr; // Rotary inv_freq cache buffer to avoid re-computing. float const* rotary_inv_freq = nullptr; // Rotary cos sin cache buffer to avoid re-computing. @@ -386,6 +388,7 @@ public: bool mPosShiftEnabled = false; bool mPagedContextFMHA = false; bool mFP8ContextFMHA = false; + bool mFP8ContextMLA = false; bool mFP8GenerationMLA = false; bool mDenseContextFMHA = false; bool mHasFullAttentionMask = false; diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp index f748022941..59c9d2fffe 100644 --- a/cpp/tensorrt_llm/common/envUtils.cpp +++ b/cpp/tensorrt_llm/common/envUtils.cpp @@ -366,6 +366,12 @@ bool getEnvForceDeterministicMOE() return forceDeterministic; } +bool getEnvMOEDisableFinalizeFusion() +{ + static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION"); + return moeDisableFinalizeFusion; +} + bool getEnvForceDeterministicAttention() { static bool const forceDeterministic @@ -386,7 +392,7 @@ size_t getEnvAllReduceWorkspaceSize() return workspaceSize; } -std::string getEnvKVCacheTransferOutputPath() +std::string const& getEnvKVCacheTransferOutputPath() { static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or(""); return outputPath; diff --git a/cpp/tensorrt_llm/common/envUtils.h b/cpp/tensorrt_llm/common/envUtils.h index 5e29dfaca7..f5c0d854ba 100644 --- a/cpp/tensorrt_llm/common/envUtils.h +++ b/cpp/tensorrt_llm/common/envUtils.h @@ -76,7 +76,7 @@ bool getEnvDisableKVCacheTransferOverlap(); bool getEnvEnableReceiveKVCacheParallel(); -std::string getEnvKVCacheTransferOutputPath(); +std::string const& getEnvKVCacheTransferOutputPath(); bool getEnvTryZCopyForKVCacheTransfer(); @@ -86,6 +86,9 @@ bool getEnvForceDeterministic(); // Force deterministic behavior for MoE plugin. bool getEnvForceDeterministicMOE(); +// Disable finalize fusion in MoE plugin +bool getEnvMOEDisableFinalizeFusion(); + // Force deterministic behavior for attention plugin. bool getEnvForceDeterministicAttention(); diff --git a/cpp/tensorrt_llm/common/stringUtils.cpp b/cpp/tensorrt_llm/common/stringUtils.cpp index 75052ad4fa..283dec8842 100644 --- a/cpp/tensorrt_llm/common/stringUtils.cpp +++ b/cpp/tensorrt_llm/common/stringUtils.cpp @@ -34,6 +34,8 @@ void fmtstr_(char const* format, fmtstr_allocator alloc, void* target, va_list a size_t constexpr init_size = 2048; char fixed_buffer[init_size]; auto const size = std::vsnprintf(fixed_buffer, init_size, format, args0); + va_end(args0); + TLLM_CHECK_WITH_INFO(size >= 0, std::string(std::strerror(errno))); if (size == 0) { diff --git a/cpp/tensorrt_llm/common/workspace.h b/cpp/tensorrt_llm/common/workspace.h index 1406e82133..0dd32ed16d 100644 --- a/cpp/tensorrt_llm/common/workspace.h +++ b/cpp/tensorrt_llm/common/workspace.h @@ -20,7 +20,8 @@ namespace tensorrt_llm::common { -std::uintptr_t constexpr kCudaMemAlign = 128; +// CuBLAS >= 12.9.1 requires 256-byte alignment. +std::uintptr_t constexpr kCudaMemAlign = 256; inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) { diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp index c10df82d54..53dc9e053a 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp @@ -27,6 +27,78 @@ namespace cutlass::gemm::collective::detail { +using namespace cute; + +typedef uint32_t __nv_fp4x8_storage_t; +typedef uint32_t __nv_bf16x2_storage_t; +typedef cutlass::uint128_t __nv_bf16x8_storage_t; + +constexpr int int4_group_size = 128; +constexpr int mxfp4_group_size = 32; + +inline __device__ unsigned prmt(unsigned hi, unsigned lo, unsigned select_code) +{ + unsigned res = 0; + + asm volatile( + "{\n" + "prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(res) + : "r"(lo), "r"(hi), "r"(select_code)); + + return res; +} + +__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_bf16(unsigned const index) +{ + const __nv_fp8x4_storage_t h4b_lut = 0x03020100U; // 7654 + const __nv_fp8x4_storage_t l4b_lut = 0xFFFEFC00U; // 3210 + + __nv_fp8x4_storage_t lut_res = prmt(h4b_lut, l4b_lut, index); + + return lut_res; +} + +__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(const __nv_fp4x8_storage_t fp4x8) +{ + __nv_bf16x8_storage_t bf16x8_raw = {0, 0}; + __nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw); + + unsigned zero_padding = 0x00000000U; + + unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U; + unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U); + + __nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7654 + __nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3210 + + bf16x2_raw[0] = prmt(zero_padding, l4b_2to9_bits, 0x1707U) >> 2U; // 1 0 + bf16x2_raw[1] = prmt(zero_padding, l4b_2to9_bits, 0x3727U) >> 2U; // 3 2 + bf16x2_raw[2] = prmt(h4b_2to9_bits, zero_padding, 0x5040U) >> 2U; // 5 4 + bf16x2_raw[3] = prmt(h4b_2to9_bits, zero_padding, 0x7060U) >> 2U; // 7 6 + + __nv_bf16x2_storage_t bf16x2_0to1_bits; + + __nv_fp8x4_storage_t h_fp8x2_0to1_bits = (fp4x8 & 0x0000C0C0U); // 3 1 + __nv_fp8x4_storage_t l_fp8x2_0to1_bits = (fp4x8 & 0x00000C0CU) << 4U; // 2 0 + + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x4707U); // 1 0 + bf16x2_raw[0] = bf16x2_raw[0] | bf16x2_0to1_bits; + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x5717U); // 3 2 + bf16x2_raw[1] = bf16x2_raw[1] | bf16x2_0to1_bits; + + h_fp8x2_0to1_bits = (fp4x8 & 0xC0C00000U); // 7 5 + l_fp8x2_0to1_bits = (fp4x8 & 0x0C0C0000U) << 4U; // 6 4 + + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x6020U); // 5 4 + bf16x2_raw[2] = bf16x2_raw[2] | bf16x2_0to1_bits; + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x7030U); // 7 6 + bf16x2_raw[3] = bf16x2_raw[3] | bf16x2_0to1_bits; + + return bf16x8_raw; +} + template <class Collective> struct MixedGroupedGemmInputUtils { @@ -46,6 +118,7 @@ private: static constexpr auto KernelConversionMode = Collective::KernelConversionMode; static constexpr auto ModeHasScales = Collective::ModeHasScales; static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; + static constexpr auto UseFP4ToBF16LookupTable = Collective::UseFP4ToBF16LookupTable; public: static constexpr auto elements_per_smem_scale() @@ -239,6 +312,27 @@ public: } } + // The core converter uses a lookup table to converts i4 -> 8 bit value. + template <class EngineIn, class LayoutIn, class EngineOut, + class LayoutOut> + CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( // Accept mutable temporaries + Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>&& dst) + { + fp4tobf16_lookup_table_convert(src, dst); + } + + template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut> + CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( + Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>& dst) + { + + // View the input as reg + auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0); + auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0); + + dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_); + } + /// Utilities to dequantize A. template <class Layout> CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) @@ -253,7 +347,6 @@ public: static_check_scale(flatten(Layout{})); } - // dequantize_A_kblock is here!!! template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts> CUTLASS_DEVICE static void dequantize_A_kblock(Tensor<EngineIn, LayoutIn> const& tCrA_load, Tensor<EngineOut, LayoutOut>& tCrA_mma, cute::tuple<Ts...>& partitioned_extra_info, int const k_block) @@ -288,8 +381,6 @@ public: } else if constexpr (UseScaleLookupTable) { - // this path - constexpr int num_elements = decltype(size(src))::value; static_assert(is_same_v<RealSwappedElementA, cutlass::int4b_t>, "Lookup table only supports int4 being the quant type now."); @@ -424,7 +515,6 @@ public: static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>); static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>); using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; Tensor src = tCrA_load(_, _, k_block); Tensor dst = tCrA_mma(_, _, k_block); @@ -441,7 +531,14 @@ public: CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<1>(dst_vm); ++i) { - LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + if constexpr (UseFP4ToBF16LookupTable) + { + fp4tobf16_lookup_table_convert(src_vm(_, i), dst_vm(_, i)); + } + else + { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + } } } diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp deleted file mode 100644 index 09ae3e013e..0000000000 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp +++ /dev/null @@ -1,568 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" - -#include "cute/numeric/numeric_types.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" - -#include "cutlass_extensions/arch/copy_red_global.hpp" -#include "cutlass_extensions/util/gather_tensor.hpp" - -#include "cutlass/epilogue/collective/builders/sm90_builder.inl" -#include "cutlass/epilogue/collective/builders/sm90_common.inl" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias, - class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R, - class CopyOpR2G> -class EpilogueMoeFusedFinalize -{ -public: - using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized; - using DispatchPolicy = PtrArrayNoSmemWarpSpecialized; - - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementIntermediate = typename ThreadEpilogueOp::ElementD; - - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t<StrideC>; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t<StrideD>; - - static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer"); - static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer"); - - using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>; - using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>; - using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>; - static constexpr int AlignmentD = CopyAtomR2G::NumValSrc; - - using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{})); - - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - - struct SharedStorage - { - alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D; - }; - - struct TensorMapStorage - { - }; - - struct Arguments - { - typename ThreadEpilogueOp::Params thread{}; - ElementC const** ptr_C{}; - StrideC dC{}; - ElementD* ptr_D{}; - StrideD dD{}; - ElementBias const* ptr_bias; - StrideBias dBias{}; - ElementScale const* ptr_scale; - StrideScale dScale{}; - int64_t const* group_offset{}; - int32_t const* scatter_index{}; - cutlass::FastDivmod num_rows_in_final_output; - }; - - using Params = Arguments; - - // - // Methods - // - - template <class ProblemShape> - static constexpr Params to_underlying_arguments( - ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace) - { - return args; - } - - template <class ProblemShape> - static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) - { - return 0; - } - - template <class ProblemShape> - static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, - void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) - { - return cutlass::Status::kSuccess; - } - - template <class ProblemShape> - CUTLASS_HOST_DEVICE static bool can_implement( - [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) - { - bool implementable = true; - if (problem_shape.is_host_problem_shape_available()) - { - // Check alignment for all problem sizes - for (int i = 0; i < problem_shape.groups(); i++) - { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M, N, K, L] = problem_shape_MNKL; - implementable = implementable - && cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{}); - } - } - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global " - "reduction instruction.\n"); - } - return implementable; - } - - CUTLASS_HOST_DEVICE - EpilogueMoeFusedFinalize(Params const& params_) - : params(params_) - { - } - - CUTLASS_DEVICE - bool is_source_needed() - { - // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. - return params.ptr_C != nullptr - && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0); - } - - template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout, - class TiledMma, class ResidueMNK> - CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma, - ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf) - { - using namespace cute; - using X = Underscore; - - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - auto synchronize = [&]() - { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto mma_tile_m = tile_size<0>(tiled_mma); - auto mma_tile_n = tile_size<1>(tiled_mma); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); - - CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); - CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - - // Batches are managed by using appropriate pointers to C and D matrices - int32_t const mock_L = 1; - int32_t const mock_l_coord = 0; - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - - // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, - // we get the correct alpha/beta values for the current batch/group using group index. - ThreadEpilogueOp epilogue_op(params.thread, l_coord); - - SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf); - - Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{}); - Tensor sD = as_position_independent_swizzle_tensor(sD_); - - // Function to scatter output rows - auto& num_rows = params.num_rows_in_final_output; - auto read_scatter_map = tensorrt_llm::cutlass_extensions::IndexedGather( - make_gmem_ptr(params.scatter_index + params.group_offset[l_coord])); - auto get_scatter_idx = [&](auto i) - { - auto scatter = read_scatter_map(i); - int quot, rem; - num_rows(quot, rem, scatter); - return rem; - }; - - // Represent the full output tensor - ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr; - auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{}; - Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l) - Tensor mD_mnl = tensorrt_llm::cutlass_extensions::make_gather_tensor( - make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l) - - // Use fake shape for bias, it doesn't matter - bool const is_bias_needed = params.ptr_bias != nullptr; - Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias); - Tensor mScale_mnl = make_tensor( - make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale); - - Tensor gC_mnl - = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl - = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - Tensor gBias_mnl - = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gScale_mnl - = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) - Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N) - - Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Get the smallest tiled copy we can use to retile the accumulators - TiledCopy tiled_copy_C_atom - = make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma); - TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom); - - auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N) - Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) - - // Make a tiled copy vectorized along major direction of D - auto tiled_s2r = [&]() - { - if constexpr (cutlass::gemm::detail::is_k_major<StrideD>()) - { - constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{}, - Layout<Shape<_1, Int<AlignmentD>>>{}); - } - else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>()) - { - constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{}, - Layout<Shape<Int<AlignmentD>, _1>>{}); - } - else - { - static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout."); - } - }(); - - auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); - Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // Allocate intermediate registers for a single subtile - Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - - // Make an identity coordinate tensor for predicating our output MN tile - Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); - Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // epilogue subtile loop - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) - { - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) - { - int mma_m = (epi_m * epi_tile_m) / mma_tile_m; - int mma_n = (epi_n * epi_tile_n) / mma_tile_n; - Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n); - - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rD); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v) - { - tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v); - } - - copy(tiled_r2s, tRS_rD, tRS_sD); - synchronize(); - - copy(tiled_s2r, tSR_sD, tSR_rD); - synchronize(); - - Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n); - Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n); - Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n); - Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n); - Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n); - - if (epilogue_op.is_source_needed()) - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n)); - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - } - } - } - -private: - Params params; -}; - -namespace detail -{ - -template <class Element, class MaxVec> -constexpr auto get_vectorized_atomic_add_op() -{ - using namespace cute; - - auto constexpr MaxVecSize = size(MaxVec{}); - - if constexpr (is_same_v<Element, cutlass::half_t>) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_F16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_F16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM70_RED_ADD_NOFTZ_F16x2{}; - } - else - { - return SM70_RED_ADD_NOFTZ_F16{}; - } - } - else if constexpr (is_same_v<Element, cutlass::bfloat16_t>) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM90_RED_ADD_NOFTZ_BF16x2{}; - } - else - { - return SM90_RED_ADD_NOFTZ_BF16{}; - } - } - else - { - // non-vectorized atomic add for all other types until supported - return TypedAtomicAdd<Element>{}; - } -} - -} // namespace detail - -template <class Arch, class TileShape, class ElementC, class StrideC, class ElementD, class StrideD, - class ElementAccumulator, class ElementCompute, class ElementBias, class StrideBias, class ElementScale, - class StrideScale> -struct EpilogueMoeFusedFinalizeBuilder -{ - - // assuming cooperative kernel schedule - using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{})); - using EpilogueTile = Shape<_128, EpiTileN>; - - // Output of linear combination is ElementCompute instead of ElementD - // since we will be doing more computate on it, no need to cast yet. - using ThreadEpilogueOp - = cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>; - - using SmemLayoutAtomD - = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>()); - using CopyAtomR2S - = decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator, EpilogueTile>()); - using CopyAtomS2R = DefaultCopy; - using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>()); - - template <class Base, class EpilogueOp> - struct TmaWarpSpecializedAdapterWithSmemStorageImpl : Base - { - // We need to override this one using declaration because otherwise we double up on the smem - using TensorMapStorage = typename EpilogueOp::TensorMapStorage; - - // using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>; - - CUTLASS_HOST_DEVICE - TmaWarpSpecializedAdapterWithSmemStorageImpl( - typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors) - : Base(params) - { - } - - CUTLASS_DEVICE auto load_init([[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx) - { - return cute::make_tuple(nullptr); - } - - CUTLASS_DEVICE auto store_init([[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx, [[maybe_unused]] int32_t warp_group_idx) - { - return cute::make_tuple(nullptr); - } - - // Dummy methods to perform different parts of TMA/Tensormap modifications - - template <bool IsLoad, class ProblemShapeMNKL> - CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShapeMNKL problem_shape, - [[maybe_unused]] int32_t next_batch, [[maybe_unused]] int32_t warp_group_idx) - { - } - - template <bool IsLoad> - CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t warp_group_idx) - { - } - - template <bool IsLoad> - CUTLASS_DEVICE void tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) - { - } - }; - - template <class EpilogueOp> - using TmaWarpSpecializedAdapterWithSmemStorage = TmaWarpSpecializedAdapterWithSmemStorageImpl< - std::conditional_t<Arch::kMinComputeCapability >= 100, detail::Sm100TmaWarpSpecializedAdapter<EpilogueOp>, - detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>>, - EpilogueOp>; - - using CollectiveOp = TmaWarpSpecializedAdapterWithSmemStorage< - EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale, - StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp new file mode 100644 index 0000000000..3571906a64 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp @@ -0,0 +1,547 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +#include "cutlass_extensions/arch/copy_red_global.hpp" +#include "cutlass_extensions/util/gather_tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +template < + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm90ScatterPtrArray { + + using SmemShape = decltype(make_shape(size(make_layout(get<0>(EpilogueTile{}))), size(make_layout(get<1>(EpilogueTile{}))))); + using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{})); + + using ElementIndex = int32_t; + // TODO: more generic treatment, or pass StrideIndex via template param? + using StrideIndex = conditional_t<cutlass::gemm::detail::is_mn_major<StrideOutput>(), Stride<_0,_1,_0>, Stride<_1,_0,_0>>; + + struct SharedStorage {}; + + struct Arguments { + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + bool use_reduction = true; + }; + + struct Params { + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + cutlass::FastDivmod index_divmod{}; // modulo used to transform the index before store + bool use_reduction = true; + }; + + template <class ProblemShape> + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return { + args.ptr_out, + args.dOut, + args.ptr_index, + cutlass::FastDivmod(args.index_modulo), + args.use_reduction + }; + } + + template <class ProblemShape> + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template <class ProblemShape> + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template <class ProblemShape> + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template <class... Args> + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class ArgsTuple + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple) + : args_tuple(std::move(args_tuple)) {} + + ArgsTuple args_tuple; + + template <typename ElementAccumulator, typename ElementInput, int FragmentSize> + CUTLASS_DEVICE auto + visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n, + Array<ElementInput, FragmentSize> const& frg_input) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + using ConvertInput = NumericArrayConverter<ElementOutput, ElementInput, FragmentSize, RoundStyle>; + ConvertInput convert_input{}; + + Tensor tC_rOut_frg = recast<Array<ElementOutput, FragmentSize>>(coalesce(tC_rOut)); // (EPI_V) + tC_rOut_frg(epi_v) = convert_input(frg_input); + + return tC_rOut_frg(epi_v); + } + + template <class STensor, class SyncFn, class VTensor> + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + Tensor byte_buffer = recast<uint8_t>(reduction_buffer); + static_assert(cosize(byte_buffer.layout()) * sizeof_bits_v<uint8_t> >= cosize(SmemLayout{}) * sizeof_bits_v<ElementOutput>, + "Not enough space in scratch smem buffer"); + + Tensor sOut = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(recast_ptr<ElementOutput>(byte_buffer.data())), SmemLayout{})); + + auto thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_sOut_epi = thread_r2s.partition_D(sOut); + Tensor tRS_rOut_epi = thread_r2s.retile_S(tC_rOut); + + auto thread_r2g = tiled_r2g_red.get_slice(thread_idx); + Tensor tRG_gOut_epi = tRG_gOut(_,_,_,epi_m,epi_n); + Tensor tRG_sOut_epi = thread_r2g.partition_D(sOut); + Tensor tRG_rOut_epi = thread_r2g.retile_S(make_tensor(tC_rOut.data(), shape(tRG_sOut_epi))); // reuse D registers + + // sanity check for register reuse + CUTE_STATIC_ASSERT_V(cosize(tC_rOut.layout()) == cosize(tRG_rOut_epi.layout()), "Invalid register count for R2G"); + + copy(tiled_r2s, tRS_rOut_epi, tRS_sOut_epi); + sync_fn(); + copy(tRG_sOut_epi, tRG_rOut_epi); + + auto residue = residue_cD; // capturing structured bindings is a C++20 feature + Tensor tRG_cD_epi = tRG_cD(0,_,_,epi_m,epi_n); + auto pred = cute::lazy::transform(tRG_cD_epi, [&](auto c){ return elem_less(c, residue); }); + + if (use_reduction) { + copy_if(tiled_r2g_red, pred, tRG_rOut_epi, tRG_gOut_epi); + } + else { + copy_if(tiled_r2g_stg, pred, tRG_rOut_epi, tRG_gOut_epi); + } + } + }; + + template <class Element, int MaxVecSize> + static constexpr auto get_reduction_op() + { + using namespace cute; + + // For now only support red.add + if constexpr (is_same_v<Element, cutlass::half_t>) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM70_RED_ADD_NOFTZ_F16x2{}; + } + else { + return SM70_RED_ADD_NOFTZ_F16{}; + } + } + else if constexpr (is_same_v<Element, cutlass::bfloat16_t>) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2{}; + } + else { + return SM90_RED_ADD_NOFTZ_BF16{}; + } + } + else { + // non-vectorized atomic add for all other types until supported + return TypedAtomicAdd<Element>{}; + } + } + + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto index_read = [index = params_ptr->ptr_index[l], divmod = params_ptr->index_divmod](auto i){ return divmod.rem(index[i]); }; + Tensor mOut = cutlass::util::make_gather_tensor(params_ptr->ptr_out, make_shape(M,N,Int<1>{}), params_ptr->dOut, index_read); // (M,N,_1) + Tensor gOut = local_tile(mOut, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gOut_epi = flat_divide(gOut, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor mIdx = make_tensor(params_ptr->ptr_index[l], make_shape(M,N,Int<1>{}), StrideIndex{}); // (M,N,_1) + Tensor gIdx = local_tile(mIdx, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gIdx_epi = flat_divide(gIdx, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor cD_epi = flat_divide(args.cD, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor tC_gOut = sm90_partition_for_epilogue<ReferenceSrc>(gOut, args.epi_tile, args.tiled_copy, args.thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Tensor tC_rOut = make_tensor<ElementOutput>(take<0,3>(shape(tC_gOut))); // (CPY,CPY_M,CPY_N) + + auto tiled_r2s = conditional_return<ReferenceSrc>( + make_tiled_copy_S(Copy_Atom<CopyOpR2S,ElementOutput>{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom<CopyOpR2S,ElementOutput>{}, args.tiled_copy) + ); + + // Vectorization must not exceed alignment and also the number of values per thread in the tile + int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy)); + int constexpr NumValTile = product(take<0,2>(shape(cD_epi))); + int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads); + + // Choose the largest available red.global op and an st.global op with matching vectorization + using CopyOpR2GRed = decltype(get_reduction_op<ElementOutput, MaxVecSize>()); + using CopyOpR2GStg = UniversalCopy<uint_bit_t<Copy_Atom<CopyOpR2GRed,ElementOutput>::NumValSrc * sizeof_bits_v<ElementOutput>>>; + + auto make_tiled_r2g = [&](auto copy_op) + { + using CopyAtomR2G = Copy_Atom<decltype(copy_op),ElementOutput>; + constexpr int VecSize = CopyAtomR2G::NumValSrc; + if constexpr (cutlass::gemm::detail::is_k_major<StrideOutput>()) { + constexpr int ThreadsMajor = size<1>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout<Shape<Int<ThreadsMinor>, Int<ThreadsMajor>>, Stride<Int<ThreadsMajor>, _1>>{}, + Layout<Shape<_1, Int<VecSize>>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major<StrideOutput>()) { + constexpr int ThreadsMajor = size<0>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout<Shape<Int<ThreadsMajor>, Int<ThreadsMinor>>, Stride<_1, Int<ThreadsMajor>>>{}, + Layout<Shape<Int<VecSize>, _1>>{}); + } + else { + static_assert(cute::is_void_v<StrideOutput>, "Unsupported D gmem layout."); + } + }; + + auto tiled_r2g_red = make_tiled_r2g(CopyOpR2GRed{}); + auto tiled_r2g_stg = make_tiled_r2g(CopyOpR2GStg{}); + + // Sanity checks - since we will be using one tiled copy with tensors partitioned with the other tiled copy, + // ensure they have matching layouts/tilers + using TiledR2GRed = decltype(tiled_r2g_red); + using TiledR2GStg = decltype(tiled_r2g_stg); + static_assert(typename TiledR2GRed::AtomLayoutSrc{} == typename TiledR2GStg::AtomLayoutSrc{}, "Mismatching AtomLayoutSrc"); + static_assert(typename TiledR2GRed::AtomLayoutDst{} == typename TiledR2GStg::AtomLayoutDst{}, "Mismatching AtomLayoutDst"); + static_assert(typename TiledR2GRed::TiledLayout_TV{} == typename TiledR2GStg::TiledLayout_TV{}, "Mismatching TiledLayout_TV"); + static_assert(typename TiledR2GRed::Tiler_MN{} == typename TiledR2GStg::Tiler_MN{}, "Mismatching Tiler_MN"); + + auto thread_r2g = tiled_r2g_red.get_slice(args.thread_idx); + Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + Tensor tRG_cD = thread_r2g.partition_D(cD_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + + auto args_tuple = make_tuple( + cute::move(tC_rOut), + tiled_r2s, + tRG_gOut, + tRG_cD, + tiled_r2g_red, + tiled_r2g_stg, + params_ptr->use_reduction, + args.thread_idx, + args.residue_cD); + + return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple)); + } +}; + +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v<ElementBias_>, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBias + : ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> +{ + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +template< + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>, + int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBiasPerColScaleScatter + : ScaledAccPerRowBias<ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> +{ + using ElementAux = ElementOutput; + using GmemLayoutTagAux = GmemLayoutTagOut; + static constexpr int AlignmentAux = AlignmentOutput; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = alpha * acc + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v<ElementBias>, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPtrArray = + Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // alpha * acc + bias + Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int64_t>>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + >; + +template< + class CtaTileShapeMNK, + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>, + int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray = + Sm90EVT<Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>, // scatter store + Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // scale * (alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_0,_1,int64_t>, 1>, // scale + Sm90ScaledAccPerRowBiasPtrArray<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> // alpha * acc + bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScale, + class ElementScalar, + int AlignmentBias, + int AlignmentOutput, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized<StagesC, + StagesD, + FragmentSize, + ReuseSmemC, + DelayTmaStore, + NumEpilogueWarpGroups + >, + fusion::ScaledAccPerRowBiasPerColScaleScatter<GmemLayoutTagOut, + ElementOutput, + ElementCompute, + ElementBias, + ElementScale, + ElementScalar, + AlignmentBias, + AlignmentOutput, + RoundStyle>, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + > { + + using StrideOutput = cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>; + + using Impl = Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + StrideOutput, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + >; + using Operation = fusion::ScaledAccPerRowBiasPerColScaleScatter< + GmemLayoutTagOut, + ElementOutput, + ElementCompute, + ElementBias, + ElementScale, + ElementScalar, + AlignmentBias, + AlignmentOutput, + RoundStyle>; + + struct Arguments { + + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* const* bias_ptr{}; + StrideBias dBias{}; + + using StrideScale = Stride<_0,_1,int64_t>; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + // Nested args not usable due to a compiler bug with constexpr evaluation + // using ScatterArguments = typename Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>::Arguments; + // ScatterArguments scatter{}; + + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + int const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + bool use_reduction = true; + + operator typename Impl::Arguments() const { + return + { // unary op: reduce(scale * (beta * C + (alpha * acc))) + { // binary op: scale * (beta * C + (alpha * acc)) + { scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end binary op + {} // binary args: multiply + }, // end binary op + //scatter // unary args: reduce + { ptr_out, dOut, ptr_index, index_modulo, use_reduction } + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; + +}; + +} // namespace cutlass::epilogue::fusion + +// clang-format on diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp index 1ee109fd64..2332950629 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp @@ -30,37 +30,12 @@ #include "cute/atom/mma_atom.hpp" #include "cute/numeric/arithmetic_tuple.hpp" -#define GROUP_SIZE 128 - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { using namespace cute; -template <int N> -CUTE_HOST_DEVICE void warpgroup_wait_() -{ -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N); - asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group<N> without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif -} - -CUTLASS_DEVICE void warpgroup_wait_dispatch(int onthefly_count) -{ - switch (onthefly_count) - { - case 0: warpgroup_wait_<0>(); break; - case 4: warpgroup_wait_<4>(); break; - case 8: warpgroup_wait_<8>(); break; - case 12: warpgroup_wait_<12>(); break; - default: assert(false && "Invalid onthefly_count value"); - } -} - ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop @@ -91,7 +66,7 @@ public: private: template <class T> friend struct detail::MixedGroupedGemmInputUtils; - using CollectiveType = CollectiveMma<DispatchPolicy, TileShape_, ElementAOptionalTuple, StrideA_, + using CollectiveType = CollectiveMmaArrayMixedInput<DispatchPolicy, TileShape_, ElementAOptionalTuple, StrideA_, ElementBOptionalTuple, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_>; using Utils = detail::MixedGroupedGemmInputUtils<CollectiveType>; @@ -146,6 +121,11 @@ public: static_assert(cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(), "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + static constexpr bool IsMXFP4 = cute::is_same_v<ElementA, cutlass::float_e2m1_t>; + // Group size 128 for int4 weights + // Group size 32 for mxfp4 weights + static constexpr int ScalingGroupSize = IsMXFP4 ? detail::mxfp4_group_size : detail::int4_group_size; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; @@ -268,6 +248,8 @@ public: || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v<ElementScale>; + static constexpr bool UseFP4ToBF16LookupTable = KernelConversionMode == ConversionMode::ConvertAndScale + && cute::is_same_v<ElementA, cutlass::float_e2m1_t> && cute::is_same_v<ElementB, cutlass::bfloat16_t>; static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); @@ -705,7 +687,7 @@ public: { // The real scale_k that actually works // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) @@ -872,7 +854,6 @@ public: } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - // zero copy auto tZgZ = get<2>(extra_input_partitions); auto tZsZ = get<3>(extra_input_partitions); if (cute::elect_one_sync()) @@ -979,7 +960,8 @@ public: return make_tensor_like<RealSwappedElementA>(tCsA(_, _, _, Int<0>{})); } }(); - Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + // tCrB is just a view of the tensor tCsB Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) // @@ -1013,8 +995,8 @@ public: multiply_add<ElementAccumulator> fma; - constexpr int NumMMAsPerChunk = GROUP_SIZE / cute::get<0, 1>(tCsB.shape())(); - constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / GROUP_SIZE; + constexpr int NumMMAsPerChunk = ScalingGroupSize / cute::get<0, 1>(tCsB.shape())(); + constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / ScalingGroupSize; cute::array<decltype(make_fragment_like(accum)), NumChunksPerTileK> intermediate_array; constexpr int K_BLOCK_MAX = size<2>(tCrA_load); @@ -1045,8 +1027,6 @@ public: // src: tCrA_load, dst: tCrA_mma Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) @@ -1079,10 +1059,11 @@ public: } } + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { - warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk); warpgroup_fence_operand(intermediate_array[chunk_id_]); // Apply the group-wise scaling @@ -1129,7 +1110,6 @@ public: Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); - warpgroup_wait<K_WAIT_MAX>(); Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); } } @@ -1169,8 +1149,6 @@ public: tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait<K_WAIT_MAX>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, - // so we can release prior barrier if (k_block == K_BLOCK_MAX - 1) { pipeline.consumer_release( @@ -1187,10 +1165,11 @@ public: { // The last k_block + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { - warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk); warpgroup_fence_operand(intermediate_array[chunk_id_]); // Apply the group-wise scaling @@ -1257,7 +1236,6 @@ public: tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait<K_WAIT_MAX>(); if (k_block == K_BLOCK_MAX - 1) { // release prior barrier @@ -1318,7 +1296,7 @@ public: smem_pipe_release.advance(k_tile_count); // Wait on all GMMAs to complete - warpgroup_wait<0>(); + // warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { @@ -1462,7 +1440,7 @@ public: { NonVoidElementScale const* ptr_S = nullptr; // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor tensor_scale = make_tensor( detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride( @@ -1472,7 +1450,7 @@ public: { ElementZero const* ptr_Z = nullptr; // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor tensor_zero = make_tensor( detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride( diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp index e529ffc1fa..a83bf6a083 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp @@ -19,7 +19,7 @@ #include "cute/tensor.hpp" #include "cute/util/print.hpp" -namespace tensorrt_llm::cutlass_extensions +namespace cutlass::util { /// Function object that applies an index to its argument @@ -81,7 +81,7 @@ struct CustomStride template <class Div> CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) { - return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div)); + return CustomStride<Func, decltype(cute::safe_div(s.stride_, div))>(s.func_, cute::safe_div(s.stride_, div)); } // Circumvent the requirement on make_layout that shape and stride are integral @@ -116,7 +116,7 @@ CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, S Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func)); return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); } -} // namespace tensorrt_llm::cutlass_extensions +} // namespace cutlass::util namespace cute { diff --git a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt index 68d74a8858..7894ec8dd6 100644 --- a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt +++ b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt @@ -1,4 +1,4 @@ -set(DEEP_EP_COMMIT edf3ea2b086a393d3163bf2773eab69d9191cc01) +set(DEEP_EP_COMMIT 515a311f290eb6d9592fcccfcc80c40f5123ca72) set(NVSHMEM_URL_HASH SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a) @@ -19,8 +19,15 @@ foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) set(CUDA_ARCH_MINOR ${CMAKE_MATCH_2}) set(CUDA_ARCH_POSTFIX ${CMAKE_MATCH_3}) if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 9) - list(APPEND DEEP_EP_CUDA_ARCHITECTURES - "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}") + # The FP4-related conversion instructions in DeepEP require SM100a, SM110a, + # or SM120a. + if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 10 AND ${CUDA_ARCH_MINOR} EQUAL 0) + list(APPEND DEEP_EP_CUDA_ARCHITECTURES + "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}a${CUDA_ARCH_POSTFIX}") + else() + list(APPEND DEEP_EP_CUDA_ARCHITECTURES + "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}") + endif() endif() endforeach() diff --git a/cpp/tensorrt_llm/executor/executor.cpp b/cpp/tensorrt_llm/executor/executor.cpp index 70ca2be41a..091bb51282 100644 --- a/cpp/tensorrt_llm/executor/executor.cpp +++ b/cpp/tensorrt_llm/executor/executor.cpp @@ -132,10 +132,12 @@ std::optional<std::shared_ptr<KVCacheEventManager>> Executor::getKVCacheEventMan return mImpl->getKVCacheEventManager(); } -KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data, SizeType32 windowSize) +KVCacheEvent::KVCacheEvent( + size_t eventId, KVCacheEventData data, SizeType32 windowSize, std::optional<SizeType32> attentionDpRank) : eventId{eventId} , data{std::move(data)} , windowSize{windowSize} + , attentionDpRank{attentionDpRank} { } diff --git a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp index 51b047ebd2..21cf314c87 100644 --- a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp +++ b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp @@ -27,6 +27,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co std::optional<size_t> const& hostCacheSize, bool onboardBlocks, std::optional<FloatType> const& crossKvCacheFraction, std::optional<RetentionPriority> secondaryOffloadMinPriority, size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm, + SizeType32 attentionDpEventsGatherPeriodMs, std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults) : mEnableBlockReuse(enableBlockReuse) , mHostCacheSize(hostCacheSize) @@ -36,6 +37,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} , mUseUvm{useUvm} + , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) { if (maxTokens) { @@ -61,6 +63,8 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co { fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value()); } + TLLM_CHECK_WITH_INFO( + mAttentionDpEventsGatherPeriodMs > 0, "Attention DP events gather period must be greater than 0"); } bool KvCacheConfig::getEnableBlockReuse() const @@ -128,6 +132,11 @@ bool KvCacheConfig::getUseUvm() const return mUseUvm; } +SizeType32 KvCacheConfig::getAttentionDpEventsGatherPeriodMs() const +{ + return mAttentionDpEventsGatherPeriodMs; +} + void KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse) { mEnableBlockReuse = enableBlockReuse; @@ -204,6 +213,12 @@ void KvCacheConfig::setUseUvm(bool useUvm) mUseUvm = useUvm; } +void KvCacheConfig::setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs) +{ + TLLM_CHECK(attentionDpEventsGatherPeriodMs > 0); + mAttentionDpEventsGatherPeriodMs = attentionDpEventsGatherPeriodMs; +} + void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults) { if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec) diff --git a/cpp/tensorrt_llm/executor/loraConfig.cpp b/cpp/tensorrt_llm/executor/loraConfig.cpp index 058b1a8671..c8499f36d4 100644 --- a/cpp/tensorrt_llm/executor/loraConfig.cpp +++ b/cpp/tensorrt_llm/executor/loraConfig.cpp @@ -27,26 +27,29 @@ LoraConfig::LoraConfig(IdType taskId, std::optional<Tensor> weights, std::option , mWeights(std::move(weights)) , mConfig(std::move(config)) { - if (mWeights.has_value() || mConfig.has_value()) + if (mConfig.has_value()) { - TLLM_CHECK_WITH_INFO(mWeights.has_value() && mConfig.has_value(), - "Request for LoRA inference must have both lora weights and lora config"); - - SizeType32 constexpr expectedWeightsDims = 2; SizeType32 constexpr expectedConfigDims = 2; + TLLM_CHECK_WITH_INFO( + mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions"); + TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU + && mConfig.value().getMemoryType() != MemoryType::kUNKNOWN, + "Expected lora config to be in CPU memory"); + TLLM_CHECK_WITH_INFO( + mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32"); + } + if (mWeights.has_value()) + { + SizeType32 constexpr expectedWeightsDims = 2; + TLLM_CHECK_WITH_INFO( + mConfig.has_value(), "Request for LoRA inference with lora weights must also have lora config"); TLLM_CHECK_WITH_INFO( mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions"); - TLLM_CHECK_WITH_INFO( - mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions"); + TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN, "Expected lora weights to be in CPU memory"); - TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU - && mConfig.value().getMemoryType() != MemoryType::kUNKNOWN, - "Expected lora weights to be in CPU memory"); - TLLM_CHECK_WITH_INFO( - mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32"); TLLM_CHECK_WITH_INFO(mConfig.value().getShape()[0] == mWeights.value().getShape()[0], "Expected dim 0 of lora weights and lora config to have the same size"); diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 65718f0405..38256edbc7 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -23,6 +23,7 @@ #include "tensorrt_llm/executor/serializeUtils.h" #include "tensorrt_llm/executor/types.h" #include "tensorrt_llm/runtime/cudaStream.h" +#include <cstddef> #include <iostream> #include <memory> #include <type_traits> @@ -1162,10 +1163,11 @@ KvCacheConfig Serialization::deserializeKvCacheConfig(std::istream& is) auto secondaryOffloadMinPriority = su::deserialize<std::optional<executor::RetentionPriority>>(is); auto eventBufferMaxSize = su::deserialize<size_t>(is); auto useUvm = su::deserialize<bool>(is); + auto attentionDpEventsGatherPeriodMs = su::deserialize<SizeType32>(is); return KvCacheConfig{enableBlockReuse, maxTokens, maxAttentionWindowVec, sinkTokenLength, freeGpuMemoryFraction, hostCacheSize, onboardBlocks, crossKvCacheFraction, secondaryOffloadMinPriority, eventBufferMaxSize, - enablePartialReuse, copyOnPartialReuse, useUvm}; + enablePartialReuse, copyOnPartialReuse, useUvm, attentionDpEventsGatherPeriodMs}; } void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os) @@ -1183,6 +1185,7 @@ void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& su::serialize(kvCacheConfig.getSecondaryOffloadMinPriority(), os); su::serialize(kvCacheConfig.getEventBufferMaxSize(), os); su::serialize(kvCacheConfig.getUseUvm(), os); + su::serialize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs(), os); } size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) @@ -1202,6 +1205,7 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) totalSize += su::serializedSize(kvCacheConfig.getSecondaryOffloadMinPriority()); totalSize += su::serializedSize(kvCacheConfig.getEventBufferMaxSize()); totalSize += su::serializedSize(kvCacheConfig.getUseUvm()); + totalSize += su::serializedSize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs()); return totalSize; } @@ -2181,6 +2185,237 @@ std::vector<RequestStatsPerIteration> Serialization::deserializeRequestStatsPerI return iterRequestStatsVec; } +// KVCacheEvents deque +std::vector<char> Serialization::serialize(std::deque<KVCacheEvent> const& eventQueue) +{ + // Compute the size of serialized buffer + size_t totalSize = 0; + totalSize += sizeof(size_t); + for (auto const& event : eventQueue) + { + totalSize += su::serializedSize(event); + } + + std::vector<char> buffer(totalSize); + std::stringbuf strbuf(std::ios_base::out | std::ios_base::in); + strbuf.pubsetbuf(buffer.data(), buffer.size()); + std::ostream os(&strbuf); + + su::serialize(eventQueue.size(), os); + for (auto const& event : eventQueue) + { + su::serialize(event, os); + } + return buffer; +} + +std::deque<KVCacheEvent> Serialization::deserializeKVCacheEvents(std::vector<char>& buffer) +{ + std::deque<KVCacheEvent> kvCacheEvents; + su::VectorWrapBuf<char> strbuf(buffer); + std::istream is(&strbuf); + auto numEvents = su::deserialize<std::size_t>(is); + for (std::size_t event = 0; event < numEvents; ++event) + { + kvCacheEvents.emplace_back(Serialization::deserializeKVCacheEvent(is)); + } + return kvCacheEvents; +} + +// KVCacheEvent +size_t Serialization::serializedSize(KVCacheEvent const& event) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(event.eventId); + totalSize += su::serializedSize(event.data); + totalSize += su::serializedSize(event.windowSize); + totalSize += su::serializedSize(event.attentionDpRank); + return totalSize; +} + +void Serialization::serialize(KVCacheEvent const& event, std::ostream& os) +{ + su::serialize(event.eventId, os); + su::serialize(event.data, os); + su::serialize(event.windowSize, os); + su::serialize(event.attentionDpRank, os); +} + +KVCacheEvent Serialization::deserializeKVCacheEvent(std::istream& is) +{ + auto eventId = su::deserialize<IdType>(is); + auto data = su::deserialize<KVCacheEventData>(is); + auto windowSize = su::deserialize<SizeType32>(is); + auto attentionDpRank = su::deserialize<std::optional<SizeType32>>(is); + + return KVCacheEvent{eventId, data, windowSize, attentionDpRank}; +} + +// KVCacheCreatedData +size_t Serialization::serializedSize(KVCacheCreatedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.numBlocksPerCacheLevel); + return totalSize; +} + +void Serialization::serialize(KVCacheCreatedData const& data, std::ostream& os) +{ + su::serialize(data.numBlocksPerCacheLevel, os); +} + +KVCacheCreatedData Serialization::deserializeKVCacheCreatedData(std::istream& is) +{ + auto numBlocksPerCacheLevel = su::deserialize<std::vector<SizeType32>>(is); + return KVCacheCreatedData{numBlocksPerCacheLevel}; +} + +// KVCacheStoredData +size_t Serialization::serializedSize(KVCacheStoredData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.parentHash); + totalSize += su::serializedSize(data.blocks); + return totalSize; +} + +void Serialization::serialize(KVCacheStoredData const& data, std::ostream& os) +{ + su::serialize(data.parentHash, os); + su::serialize(data.blocks, os); +} + +KVCacheStoredData Serialization::deserializeKVCacheStoredData(std::istream& is) +{ + auto parentHash = su::deserialize<std::optional<IdType>>(is); + auto blocks = su::deserialize<std::vector<KVCacheStoredBlockData>>(is); + return KVCacheStoredData{parentHash, blocks}; +} + +// KVCacheStoredBlockData +size_t Serialization::serializedSize(KVCacheStoredBlockData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHash); + totalSize += su::serializedSize(data.tokens); + totalSize += su::serializedSize(data.loraId); + totalSize += su::serializedSize(data.cacheLevel); + totalSize += su::serializedSize(data.priority); + return totalSize; +} + +void Serialization::serialize(KVCacheStoredBlockData const& data, std::ostream& os) +{ + su::serialize(data.blockHash, os); + su::serialize(data.tokens, os); + su::serialize(data.loraId, os); + su::serialize(data.cacheLevel, os); + su::serialize(data.priority, os); +} + +KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::istream& is) +{ + auto blockHash = su::deserialize<IdType>(is); + auto tokens = su::deserialize<tensorrt_llm::runtime::VecUniqueTokens>(is); + auto loraId = su::deserialize<std::optional<tensorrt_llm::runtime::LoraTaskIdType>>(is); + auto cacheLevel = su::deserialize<SizeType32>(is); + auto priority = su::deserialize<SizeType32>(is); + + return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority}; +} + +// KVcacheRemovedData + +size_t Serialization::serializedSize(KVCacheRemovedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHashes); + return totalSize; +} + +void Serialization::serialize(KVCacheRemovedData const& data, std::ostream& os) +{ + su::serialize(data.blockHashes, os); +} + +KVCacheRemovedData Serialization::deserializeKVCacheRemovedData(std::istream& is) +{ + auto blockHashes = su::deserialize<std::vector<IdType>>(is); + return KVCacheRemovedData{blockHashes}; +} + +// KVCacheEventDiff +template <typename T> +size_t Serialization::serializedSize(KVCacheEventDiff<T> const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.oldValue); + totalSize += su::serializedSize(data.newValue); + return totalSize; +} + +template <typename T> +void Serialization::serialize(KVCacheEventDiff<T> const& data, std::ostream& os) +{ + su::serialize(data.oldValue, os); + su::serialize(data.newValue, os); +} + +template <typename T> +KVCacheEventDiff<T> Serialization::deserializeKVCacheEventDiff(std::istream& is) +{ + auto oldValue = su::deserialize<T>(is); + auto newValue = su::deserialize<T>(is); + return KVCacheEventDiff<T>{oldValue, newValue}; +} + +// KVCacheUpdatedData +size_t Serialization::serializedSize(KVCacheUpdatedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHash); + totalSize += su::serializedSize(data.cacheLevel); + totalSize += su::serializedSize(data.priority); + return totalSize; +} + +void Serialization::serialize(KVCacheUpdatedData const& data, std::ostream& os) +{ + su::serialize(data.blockHash, os); + su::serialize(data.cacheLevel, os); + su::serialize(data.priority, os); +} + +KVCacheUpdatedData Serialization::deserializeKVCacheUpdatedData(std::istream& is) +{ + auto blockHash = su::deserialize<IdType>(is); + auto cacheLevel = su::deserialize<std::optional<KVCacheEventDiff<SizeType32>>>(is); + auto priority = su::deserialize<std::optional<KVCacheEventDiff<SizeType32>>>(is); + return KVCacheUpdatedData{blockHash, cacheLevel, priority}; +} + +// UniqueToken +size_t Serialization::serializedSize(tensorrt_llm::runtime::UniqueToken const& token) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(token.tokenId); + totalSize += su::serializedSize(token.tokenExtraId); + return totalSize; +} + +void Serialization::serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os) +{ + su::serialize(token.tokenId, os); + su::serialize(token.tokenExtraId, os); +} + +tensorrt_llm::runtime::UniqueToken Serialization::deserializeUniqueToken(std::istream& is) +{ + auto tokenId = su::deserialize<tensorrt_llm::runtime::TokenIdType>(is); + auto tokenExtraId = su::deserialize<tensorrt_llm::runtime::TokenExtraIdType>(is); + return tensorrt_llm::runtime::UniqueToken{tokenId, tokenExtraId}; +} + // String std::string Serialization::deserializeString(std::istream& is) { diff --git a/cpp/tensorrt_llm/executor/serializeUtils.h b/cpp/tensorrt_llm/executor/serializeUtils.h index 8f26c58d62..40b50f9230 100644 --- a/cpp/tensorrt_llm/executor/serializeUtils.h +++ b/cpp/tensorrt_llm/executor/serializeUtils.h @@ -122,6 +122,14 @@ static_assert(hasSerializedSize<GuidedDecodingParams>(size_t())); static_assert(!hasSerializedSize<std::string>(size_t())); static_assert(!hasSerializedSize<std::optional<float>>(size_t())); static_assert(hasSerializedSize<CacheTransceiverConfig>(size_t())); +static_assert(hasSerializedSize<KVCacheEvent>(size_t())); +static_assert(hasSerializedSize<KVCacheCreatedData>(size_t())); +static_assert(hasSerializedSize<KVCacheStoredData>(size_t())); +static_assert(hasSerializedSize<KVCacheStoredBlockData>(size_t())); +static_assert(hasSerializedSize<KVCacheRemovedData>(size_t())); +static_assert(hasSerializedSize<KVCacheEventDiff<SizeType32>>(size_t())); +static_assert(hasSerializedSize<KVCacheUpdatedData>(size_t())); +static_assert(hasSerializedSize<tensorrt_llm::runtime::UniqueToken>(size_t())); template <typename T> size_t serializedSize(T const& data) @@ -219,6 +227,14 @@ static_assert(hasSerialize<ContextPhaseParams>(nullptr)); static_assert(!hasSerialize<std::string>(nullptr)); static_assert(!hasSerialize<std::optional<float>>(nullptr)); static_assert(hasSerialize<CacheTransceiverConfig>(nullptr)); +static_assert(hasSerialize<KVCacheEvent>(nullptr)); +static_assert(hasSerialize<KVCacheCreatedData>(nullptr)); +static_assert(hasSerialize<KVCacheStoredData>(nullptr)); +static_assert(hasSerialize<KVCacheStoredBlockData>(nullptr)); +static_assert(hasSerialize<KVCacheRemovedData>(nullptr)); +static_assert(hasSerialize<KVCacheEventDiff<SizeType32>>(nullptr)); +static_assert(hasSerialize<KVCacheUpdatedData>(nullptr)); +static_assert(hasSerialize<tensorrt_llm::runtime::UniqueToken>(nullptr)); template <typename T> void serialize(T const& data, std::ostream& os) @@ -291,6 +307,22 @@ struct get_variant_alternative_type } }; +template <typename T> +T deserialize(std::istream& is); + +// Helper function to deserialize variant by index using template recursion +template <typename T, std::size_t... Is> +T deserializeVariantByIndex(std::istream& is, std::size_t index, std::index_sequence<Is...> /*indices*/) +{ + T result; + bool found = ((Is == index ? (result = deserialize<std::variant_alternative_t<Is, T>>(is), true) : false) || ...); + if (!found) + { + TLLM_THROW("Invalid variant index during deserialization: " + std::to_string(index)); + } + return result; +} + // Deserialize template <typename T> T deserialize(std::istream& is) @@ -511,6 +543,38 @@ T deserialize(std::istream& is) { return Serialization::deserializeCacheTransceiverConfig(is); } + else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheEvent>) + { + return Serialization::deserializeKVCacheEvent(is); + } + else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheCreatedData>) + { + return Serialization::deserializeKVCacheCreatedData(is); + } + else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheStoredData>) + { + return Serialization::deserializeKVCacheStoredData(is); + } + else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheStoredBlockData>) + { + return Serialization::deserializeKVCacheStoredBlockData(is); + } + else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheRemovedData>) + { + return Serialization::deserializeKVCacheRemovedData(is); + } + else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheEventDiff<SizeType32>>) + { + return Serialization::deserializeKVCacheEventDiff<SizeType32>(is); + } + else if constexpr (std::is_same_v<T, tensorrt_llm::executor::KVCacheUpdatedData>) + { + return Serialization::deserializeKVCacheUpdatedData(is); + } + else if constexpr (std::is_same_v<T, tensorrt_llm::runtime::UniqueToken>) + { + return Serialization::deserializeUniqueToken(is); + } // Optional else if constexpr (std::is_same_v<T, std::optional<typename ValueType<T>::type>>) { @@ -547,23 +611,7 @@ T deserialize(std::istream& is) std::size_t index = 0; is.read(reinterpret_cast<char*>(&index), sizeof(index)); - // TODO: Is there a better way to implement this? - T data; - if (index == 0) - { - using U = std::variant_alternative_t<0, T>; - data = deserialize<U>(is); - } - else if (index == 1) - { - using U = std::variant_alternative_t<1, T>; - data = deserialize<U>(is); - } - else - { - TLLM_THROW("Serialization of variant of size > 2 is not supported."); - } - return data; + return deserializeVariantByIndex<T>(is, index, std::make_index_sequence<std::variant_size_v<T>>{}); } else { diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu index 27d041618e..84710a9636 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu @@ -256,9 +256,9 @@ public: constexpr int SF_VEC_SIZE = 16; using PackedVec = PackedVec<DType>; PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&val); - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt, token_id, - m_access_id_in_token, std::nullopt, m_params.hidden_dim, - reinterpret_cast<uint32_t*>(m_params.scale_out), m_params.layout); + auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt, token_id, m_access_id_in_token, + std::nullopt, m_params.hidden_dim / SF_VEC_SIZE, reinterpret_cast<uint32_t*>(m_params.scale_out), + m_params.layout); reinterpret_cast<uint32_t*>(m_params.quant_out)[m_access_id] = cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, m_scale_factor, sf_out); } diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h index dbf45ebe1c..52487b25d4 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h @@ -132,7 +132,7 @@ struct AllReduceFusionParams float rms_eps; float* scale_factor; bool use_oneshot; - FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED; + QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED; cudaStream_t stream; AllReduceFusionPattern pattern; bool trigger_completion_at_end = true; diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu index 2176ba759f..c38abd9578 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu @@ -99,15 +99,15 @@ __device__ struct __attribute__((aligned(32))) LamportFlags uint32_t* offset_access_ptr; uint32_t* buffer_flags; - __device__ explicit LamportFlags(uint32_t* buffer_flags) + __device__ explicit LamportFlags(uint32_t* buffer_flags, uint32_t buffer_size) : offset_access_ptr(&buffer_flags[4]) , buffer_flags(buffer_flags) + , buffer_size(buffer_size) { uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0]; - buffer_size = flag.z; input_offset = flag.x * (buffer_size << 1U); clear_offset = flag.y * (buffer_size << 1U); - num_tokens_prev = flag.w; + num_tokens_prev = flag.z; } __device__ void cta_arrive() @@ -135,7 +135,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0]; buffer_flags[0] = (flag.x + 1) % 3; buffer_flags[1] = (flag.y + 1) % 3; - buffer_flags[3] = num_tokens; + buffer_flags[2] = num_tokens; *(offset_access_ptr) = 0; } } @@ -144,7 +144,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags template <int WORLD_SIZE, typename T> __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens, - int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results) + int buffer_M, int token_dim, int rank, uint32_t buffer_size, uint32_t* buffer_flags, bool wait_for_results) { int elt = blockIdx.y * blockDim.x + threadIdx.x; if (elt >= token_dim) @@ -155,7 +155,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ cudaGridDependencySynchronize(); #endif - LamportFlags flags(buffer_flags); + LamportFlags flags(buffer_flags, buffer_size); // Capture the number of tokens in previous iteration so that we can properly clear the buffer // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up @@ -217,15 +217,17 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif - - // Similarly clear broadcast buffer here - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) + if (elt < token_dim) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) + // Similarly clear broadcast buffer here + for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt] - = fromFloat<T>(-0.f); + uint32_t clr_token_idx = token + clr_tok * gridDim.x; + if (clr_token_idx < buffer_M) + { + input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt] + = fromFloat<T>(-0.f); + } } } @@ -240,20 +242,24 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32) if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD)) { - uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; + uint64_t elt_load_offset = blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; + if (elt_load_offset < token_dim) + { + uint64_t current_pos = blockIdx.x * token_dim + elt_load_offset; - void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; - // We have 2 assumptions here: - // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B - // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) - float2 val = loadfloat2(lamport_ptr); - while (isNegZero(*(T*) &val)) - { - val = loadfloat2(lamport_ptr); - } - if (output_ptr) - { - *((float2*) &output_ptr[current_pos]) = val; + void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; + // We have 2 assumptions here: + // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B + // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) + float2 val = loadfloat2(lamport_ptr); + while (isNegZero(*(T*) &val)) + { + val = loadfloat2(lamport_ptr); + } + if (output_ptr) + { + *((float2*) &output_ptr[current_pos]) = val; + } } } @@ -263,10 +269,11 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ } #define LAUNCH_ALL_REDUCE_KERNEL(WORLD_SIZE, T) \ - TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel<WORLD_SIZE, T>, \ - reinterpret_cast<T*>(params.output), reinterpret_cast<T*>(params.input), \ - reinterpret_cast<T**>(params.buffer_ptrs_dev), (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, \ - params.token_dim, params.rank, reinterpret_cast<uint32_t*>(params.buffer_flags), params.wait_for_results)); + TLLM_CUDA_CHECK( \ + cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel<WORLD_SIZE, T>, reinterpret_cast<T*>(params.output), \ + reinterpret_cast<T*>(params.input), reinterpret_cast<T**>(params.buffer_ptrs_dev), \ + (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, params.token_dim, params.rank, \ + params.buffer_size, reinterpret_cast<uint32_t*>(params.buffer_flags), params.wait_for_results)); void twoshot_allreduce_op(AllReduceParams const& params) { @@ -369,20 +376,33 @@ inline __device__ T add(T a, T b) } #define FINAL_MASK 0xffffffff +#define WARP_SIZE 32 template <typename T> __inline__ __device__ T warpReduceSum(T val) { + // Get the actual number of active threads in this warp + int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1))); + unsigned int mask = (1U << active_warp_size) - 1; + #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 + for (int offset = 16; offset > 0; offset >>= 1) + { + if (offset < active_warp_size) + { + val = add<T>(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE)); + } + } return val; } inline __device__ float block_reduce_sum(float val) { - __shared__ float smem[32]; - int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32; + __shared__ float smem[WARP_SIZE]; + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps + val = warpReduceSum(val); if (lane_id == 0) { @@ -391,6 +411,7 @@ inline __device__ float block_reduce_sum(float val) __syncthreads(); val = lane_id < warp_num ? smem[lane_id] : 0.f; val = warpReduceSum(val); + return val; } @@ -410,7 +431,7 @@ __device__ float4 loadfloat4(void const* ptr) template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN> __global__ void __launch_bounds__(128, 1) RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon, - T_IN const* residual, int batch_size, uint32_t* buffer_flags) + T_IN const* residual, int batch_size, uint32_t buffer_size, uint32_t* buffer_flags) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) static bool const LAMPORT = true; @@ -433,7 +454,7 @@ __global__ void __launch_bounds__(128, 1) int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; - LamportFlags flags(buffer_flags); + LamportFlags flags(buffer_flags, buffer_size); T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; cudaTriggerProgrammaticLaunchCompletion(); @@ -598,16 +619,15 @@ __global__ void __launch_bounds__(128, 1) #endif } -template <typename T, int H_DIM> +template <typename T, int H_DIM, int NUM_THREADS> void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T const* gamma, double epsilon, - T const* residual, uint32_t* buffer_flags, int batch, cudaStream_t stream) + T const* residual, uint32_t buffer_size, uint32_t* buffer_flags, int batch, cudaStream_t stream) { // input to rmsnorm is the buffer in the twoshot ar // We should use prenorm output to determine the actual used size float _epsilon{static_cast<float>(epsilon)}; - static constexpr int NUM_THREADS = 128; static constexpr int CGA_THREADS = NUM_THREADS; constexpr int iters = H_DIM / CGA_THREADS; @@ -628,28 +648,34 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons &RMSNorm<H_DIM, NUM_THREADS, 1, T, T>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); config.dynamicSmemBytes = shmem_size; TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &RMSNorm<H_DIM, NUM_THREADS, 1, T, T>, prenorm_output, normed_output, - input, gamma, _epsilon, residual, batch, buffer_flags)); + input, gamma, _epsilon, residual, batch, buffer_size, buffer_flags)); } -#define LAUNCH_RMSNORM_KERNEL(T, H_DIM) \ - twoshot_rmsnorm<T, H_DIM>(static_cast<T*>(params.residual_output), static_cast<T*>(params.output), \ +#define LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS) \ + twoshot_rmsnorm<T, H_DIM, NUM_THREADS>(static_cast<T*>(params.residual_output), static_cast<T*>(params.output), \ static_cast<T const*>(params.input), static_cast<T const*>(params.gamma), params.epsilon, \ - static_cast<T const*>(params.residual), params.buffer_flags, params.batch, params.stream) + static_cast<T const*>(params.residual), params.buffer_size, params.buffer_flags, params.batch, params.stream) void twoshot_rmsnorm_op(RMSNormParams const& params) { auto dtype = params.dtype; + +#define CASE_DISPATCH_RMSNORM(T, H_DIM, NUM_THREADS) \ + case H_DIM: LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS); break; + +#define TYPE_DISPATCH_RMSNORM(T) \ + CASE_DISPATCH_RMSNORM(T, 2048, 128) \ + CASE_DISPATCH_RMSNORM(T, 2880, 120) \ + CASE_DISPATCH_RMSNORM(T, 4096, 128) \ + CASE_DISPATCH_RMSNORM(T, 5120, 128) \ + CASE_DISPATCH_RMSNORM(T, 7168, 128) \ + CASE_DISPATCH_RMSNORM(T, 8192, 128) + if (dtype == nvinfer1::DataType::kFLOAT) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(float, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(float, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(float, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(float, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(float, 8192); break; + TYPE_DISPATCH_RMSNORM(float); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -657,13 +683,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 8192); break; + TYPE_DISPATCH_RMSNORM(__nv_bfloat16); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -671,13 +691,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break; + TYPE_DISPATCH_RMSNORM(__nv_half); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -685,6 +699,8 @@ void twoshot_rmsnorm_op(RMSNormParams const& params) { TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype."); } +#undef TYPE_DISPATCH_RMSNORM +#undef CASE_DISPATCH_RMSNORM } } // namespace tensorrt_llm::kernels::mnnvl diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h index ccca256b5a..3a0fb753db 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h @@ -30,6 +30,7 @@ struct AllReduceParams int buffer_M; int num_tokens; int token_dim; + uint32_t buffer_size; void** buffer_ptrs_dev; void* multicast_ptr; void* buffer_flags; @@ -50,6 +51,7 @@ struct RMSNormParams void const* gamma; double epsilon; void* residual; + uint32_t buffer_size; uint32_t* buffer_flags; int batch; int hidden_dim; diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu index 577f4b5ff4..7bc9e326fb 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu @@ -150,8 +150,8 @@ __device__ __forceinline__ void fused_op( constexpr int SF_VEC_SIZE = 16; using PackedVec = PackedVec<DType>; PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&norm_val); - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt /* batchIdx */, - token_id, access_id_in_token, std::nullopt /* numRows */, params.hidden_dim, + auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt /* batchIdx */, token_id, + access_id_in_token, std::nullopt /* numRows */, params.hidden_dim / SF_VEC_SIZE, reinterpret_cast<uint32_t*>(params.scale_out), params.layout); reinterpret_cast<uint32_t*>(params.quant_out)[access_id] = cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, *params.scale_factor, sf_out); diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h index 9ebc7de650..4a35d14bf0 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h @@ -55,7 +55,7 @@ struct AllReduceFusionParams void* rms_gamma; float rms_eps; float* scale_factor; - FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED; + QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED; cudaStream_t stream; }; diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h index 66dc990d18..0f8d0cdabd 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h @@ -216,7 +216,8 @@ extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90( extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_64_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); -extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); +extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); +extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_32_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_40_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_48_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); @@ -1820,6 +1821,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_sliding_or_chunked_causal_tma_ws_sm90_kernel", 196864, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_128_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, true, false, false, true, false, nullptr}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, true, false, false, true, false, nullptr}, @@ -1833,6 +1835,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_kernel", 164096, 384, 64, 0, 1, false, true, true, true, false, false, false, false, nullptr}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, false, nullptr}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90_kernel", 73984, 384, 64, 0, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90}, @@ -1873,11 +1876,13 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_sliding_or_chunked_causal_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 1, 2, false, true, true, true, false, false, true, false, nullptr}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 2, 2, false, true, true, true, false, false, true, false, nullptr}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 1, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90_kernel", 73984, 384, 64, 0, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_custom_mask_softmax_tma_ws_sm90_kernel", 73984, 384, 64, 3, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90}, @@ -1887,7 +1892,10 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_custom_mask_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_softmax_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 0, 1, false, true, true, true, false, false, false, true, nullptr}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_custom_mask_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, true, nullptr}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 0, 2, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_softmax_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_softmax_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_32_alibi_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 40, 40, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_40_alibi_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 48, 48, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_48_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_48_alibi_tma_ws_sm90}, @@ -2522,7 +2530,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_causal_sm89_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sliding_or_chunked_causal_sm89_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_custom_mask_sm89_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_kernel_nl_tiled", 32768, 128, 64, 0, 0, false, true, false, true, true, true, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, nullptr}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_kernel_nl_tiled", 32768, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 1, 0, false, true, false, true, true, false, false, true, nullptr}, @@ -2531,7 +2541,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sliding_or_chunked_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_custom_mask_output_bf16_sm89_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 0, 0, false, true, false, true, true, true, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, nullptr}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 64, 32, 32, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr}, @@ -3144,7 +3156,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm89_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm89_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm89_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_nl_tiled}, #endif @@ -3756,7 +3770,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm80_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm80_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm80_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_nl_tiled}, #endif @@ -4368,13 +4384,17 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm86_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm86_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm86_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_nl_tiled}, #endif #ifndef EXCLUDE_SM_100 +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_nl_tiled}, #endif @@ -4784,7 +4804,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm120_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_softcapping_sm120_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_causal_softcapping_sm120_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm120_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_nl_tiled}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 128, 128, 32, 32, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_kernel_nl", 12288, 128, 128, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_nl}, @@ -4874,7 +4896,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_causal_sm120_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sliding_or_chunked_causal_sm120_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_custom_mask_sm120_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_nl_tiled}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_nl_tiled}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_nl_tiled}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_nl_tiled}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_nl_tiled}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm120_nl}, @@ -4883,7 +4907,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sliding_or_chunked_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_custom_mask_output_bf16_sm120_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_nl_tiled}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_nl_tiled}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_nl_tiled}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_nl_tiled}, { DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 128, 128, 16, 16, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_causal_sm120_kernel_nl_tiled", 16384, 128, 128, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sm120_nl_tiled}, diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 81208594d0..0000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d5bb139b12206a563daec9fa473dda422319bde5ae5f965d37cf5ca67d325c49 -size 1005546 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 7086ad9f48..0000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c4357a935656d47414a459939720b66311c67213f450168715e1cb0238653768 -size 1066324 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index 0acae9aa71..2ae91e52cd 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0a0671e7cbbed9f51dc0c47e4b970e2f72067d629ff6562c9d65f9cd55c68578 -size 361861 +oid sha256:c709dce149c0f4500539e495c90d1da2d86cec28c4187ee9494b015642e158cf +size 363441 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 4cb6bcd1c1..bce0c66bcf 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5ec9817bebb07483ce29d8d91c45d35c2c05f0101bfa70146fba5a6576a6b825 -size 1091614 +oid sha256:b9170581da010aca67f4bafd9f6f59aaaf5fd1958a1fdd336aa208146599ac06 +size 1094770 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 470904148a..caa735d572 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0540cdb398818ec54a60c34b462c158e169347db73d244d633669d74211696ba -size 1467312 +oid sha256:2147a246067f7ea74ca382fbc8c02a26332479e5205ecfbe08fb84161a3a87ec +size 1483888 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index 281985341d..0b584163a8 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:69bdfba64f1faff30ed8389a28b7b9ef37c0d180b1df643722b280011c8f74e8 -size 692990 +oid sha256:279bd48b8ac53690bb4e37dffbe9060428db80c1417ff29c6f4d4a10ab35a7c9 +size 700094 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index 8b8738474d..496df695fc 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c8173308813999ab64ba8236016b23fbfd3f3f1501f61290bf71ea027ead2920 -size 642456 +oid sha256:db5d186ce70d7a94cae2b6619b3449ca557903944beba1ee738d2ee425792d74 +size 652718 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 6ca952af64..c6692932cd 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f41ae066b01b2a9c3b5165535f743461a9a1d559f6fcd0a00a04c554f8a50962 -size 414757 +oid sha256:089a98cf8ab0bbd7530e69821c42220ea02578b740bff62a3e6e33de45209114 +size 416335 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1a973c5d2e..555f626864 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ab0be8e667d459e13135f96469613f1c095e47187b24e5d40c7c57583351a076 -size 1194236 +oid sha256:1f0cc486ec5e9c1720f495a2a5e7c26d42e737694d307d4746a08b6ead5cc225 +size 1197394 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index 8faf85254d..b5884bba55 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:03d86280f76994e2e01d43747cb5c811496b8340d031ebb0c3bdd46437422994 -size 1654394 +oid sha256:398965e34c1a4c747b42d8836c04934daaa43903b7931586ed12120e17a61f76 +size 1672548 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 53f3032a30..696620f879 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:35c5715bcb1a16c343f3a28be105fb6fee1bbca24cf832f71a7d0f20cf9a0b3e -size 365015 +oid sha256:77cbd7d45164d24be73e021bc0a8745b4f021e4369a254e216ee00b36d3c7263 +size 366593 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp index 89a4eaa580..22a4ff75bf 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a3335a8d4b2c0ca63f006c3f957d57aa3f808ef06d4adda322c311a333286d84 +oid sha256:3a3f74fbe72ef54b9c028d957353c1ecbff1d20bcc9619ff17ee37471934a2ab size 1126352 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index 9cb2eb33c2..e0b9335b45 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fdc0bf099862d352b3b765e117437240a82e4749d3efd104881647dd4ea14562 +oid sha256:b3af082c6742f385d0d2c96489ff1de314458eb992d6d5a251c737f8ec912e79 size 644092 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 153555cbe4..ec999849fa 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ccd938df8f78af4eae306c6e9e669599c2baf6f095f956318470063c560fbd3c -size 1091610 +oid sha256:8e26f3b8cc173301b3cf07ba1ca7893b6f140432410b0b298361ecff597604c2 +size 1095556 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index cab205493a..284e084f3d 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ce4d35ab4c7b65476f0dcec635db1791fcb718afd6b3531338712f5b2bc9aa84 -size 1460204 +oid sha256:32220d11bc3542e9edcc36d51b4866bf40044213114d7e237e003afc1fc7c464 +size 1478358 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp index ab21a448f5..69a3f4789c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d088ce37b21d335ba1f92034cf97f78fc968d7fecaa0c4f9ec83a0d5165f1d99 +oid sha256:3ee5ae75df4866d848e90616562345d3740b17b68c90f06329dc074dba5217a9 size 482709 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp index 2fa6ba246e..c19635d688 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:40653ec672098e2cb1f94c473fa67852efcf6b49a6e8109e4fcf39422281acb4 +oid sha256:817ae5c1eb8a8c6f22a76ab0b88075fd3391d06abb7dd6d9ab51206b809cd69d size 657930 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp index ebdb0563ef..a625def240 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:96348957990518db6f51af7c681a71e625dede568cc8f8303dd2de8ad09bfc28 +oid sha256:680734da0abb1c3029dce32e892687f649c4219f66574acb15ab88471f508263 size 677218 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index 7cd5b267e0..1691a77e1f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4687df80ac2fa9454b0564b0a80d78cfaedc2c7796c8f3a1010dd7ebbf722c83 +oid sha256:c27e871dd680022920081c30c5e239613e53b42129680fdb1d17668b5c5ddd9a size 369401 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp index f4da9b9d86..6e7098d6c7 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d8b9985065f5f2c62b74c05f8eed02b1909c96656b26fbd7779cc57a2146b037 -size 947140 +oid sha256:3e1ecaa635067924b692b665241d86e1d8c1d60a19290de7adde1ff2ca7dbeb0 +size 956612 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 8ffdb6589d..c38c3b29fd 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:23599e63b07ad966df921daf3cb97a9ed5cde27eeda0fd96ba5abd835b48f89a -size 590779 +oid sha256:d3018c622303f89c6f22f037ec99eaeaeea9cfe8911e22463b48a22c13116805 +size 592357 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1153714c7e..5d286a73e5 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:cd1c452565583b20913d835de9b14c2f19c0cc431bc926ea6c92295362a85bca -size 1813864 +oid sha256:a7a381f2855236f418a40124a5254401c95001d5e15c074a704e22cc7ed89aa2 +size 1818600 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index b6383dcbd5..5290f97cfb 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b20de2c6bb3081564ddfbf7ece80fb2c17e66f4e7ff0e0969da4e4655e90d1ec -size 2407418 +oid sha256:9bb49ace4dedc4faa3de2b9c22e09db0f3990129ce7ab4afb6419c38a5d48a16 +size 2427152 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 3713748af5..cb3d89f070 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:33a0e8bb2391128e688e5c6356f09a5ed189ce5c1bcdeef4efc0ce0415dc2849 -size 555245 +oid sha256:9769d7cb9754718798be515c84c45ff48e43322573f3f12e31c2e42e99d8dbd4 +size 557613 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp index 795d4d68fc..de925119b3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4b014f41b1cfdf6ed2729778841213a36440191eb3c087346a02c21510bd3f0e -size 665794 +oid sha256:134f4a73e0e6b02b717319ec49e3b3ea0a585cad385a1f300e6c5761f12de9d7 +size 671320 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 5c8dbe22b2..64bb52e0df 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bd77afeb7dcd1ff8d6be80788b20e92e4fbc8c3026ba12d1d522c99316754a7c -size 1740442 +oid sha256:7935b0f053a79a7e620c0efe274fa5b4c840fc9c6e439a381c4d380446e1cb68 +size 1744388 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp index ee1a46c9bc..87d96af432 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b674707d02aac297b66d523de8b11618ca1598c49eeaf7ce9b1c9d516ce95c4b -size 2247958 +oid sha256:74ecbbaa19b2efe97a3b12c488f0e03c2102f16c460239df4bfc19976fc4365e +size 2266902 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp index 349c2efdfe..15ad1d62a9 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7556f88488e05ee669e763b839afa1b7690060cfa9d8482d419c0ca336df9352 +oid sha256:813265d25709bd2d39982efbaf092c9163b124bd990fccab505b3c22134522aa size 595585 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp index 2ccc55f144..4e62255a62 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ac9d879aa0c70967bb3a79cd7034998baf43a544c0dd4444ebddeb76e78df5ae +oid sha256:dd36195c01bf7c2a2013d5f31d2e74c2579c471385d7b45be7e35ea2f0652608 size 908162 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp index ec1ef8aae9..10ee7b3d8c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e781c0278fc46142f578ae51bfeb38767e89d9c25b92023215948f99dd1d3ed +oid sha256:31d4d6dca68c4632d1f435e9179582cfe2ad7a75ee0f7625ee67b0044c914f10 size 1371512 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp index d904de0acb..407d34a655 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d608e9e3ec460d2a38f43067a7d7a2dd408e068db690806bbafb11007e175336 +oid sha256:6570d3ee7b651dec797e82b31eb21fd3261c6e2639fb7c9b157f251bf98bb3bf size 1419662 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp index 798e8482b4..d6b829a9a0 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9c1e1d300866c6425c2495e550230051debdca0a7eb85874ae33c0c2de8a81cb +oid sha256:88b972677c5436b90fe85870278e3b23d6f709608f99295bddf0be3861d95d1a size 1419662 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp index bbcce09e72..7cac9a8325 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:132d83639e34af1b431abdcb3f09542d0389030b85752e18a3ae221ead7d24a3 +oid sha256:d975f605d62c3070d6cf72f6114d98642c520e66989ed2d2845c3213e921ebf7 size 1965880 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp index 83287a0376..9dd7d6bf8e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4a96710f6c691580c2363c187a75fd436f5e6be732810a1a45182ce72dc52d1e +oid sha256:ef5a2728cbd3241f45f3d8285c91a818e11b2a9fedf322f343a9461d31a6ad30 size 1380182 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp index 0062377934..1b6d6cddf5 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a6339f008f451d030aa36a6b3fac7179e7534f7f2474d641fa0ebfbf487074e7 +oid sha256:16b5f3d3f8760dabc0849217cf11edf18d19896dda475a5fc233bbfd444faf33 size 1401494 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp index 0d719af97a..90decb8793 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:57ebcae2b70fc28881f2b3969868d64c203ef4a9cbc9588a9e28051c5f5b6849 +oid sha256:cbacb235f39adaeabd68e2fc46c51aac6ca26cdf96293a6a7eb60b5be40640ef size 1401494 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp index ceab132d42..5628ced1f3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5e2a4ce1b944feb2b3ed535943089a2d5968bf523b149885df78f7fa4bd7e835 +oid sha256:e6f3e068435339a64d47673f8018b66c202f6259d68e0a97a4a30acb7505a7fd size 1935872 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp index 2780675d9d..552a78df4f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f5d456b30f89ad05ba5b852fabcffb3f8269913d83ef8c0e4e319f2243dee54d +oid sha256:7c2d7ab0692de5405b26d19a0c57d720285366ac12a8550bbabca1613cce7f0c size 305897 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp index 2aa3fd4b0a..ca2d2a604d 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:85593d3c2fecb6842a72952c6dcbde19a70e6b26245829d279ca50bb391eb636 +oid sha256:91a26adfddc0bcaf8b42249f59f1a0b9f74be0f82c7378fe4b56f3a2fa3d4bf1 size 290109 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp index b050acbb5a..da475b4a2d 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:69cd61bd8334d2109067ef0460a91b8dba4c2cb07392eb636d72d025ccb15bf9 +oid sha256:6ef79c9e2e2d8bba55d7803dc8dc147b5d8babc29e906a43407a8722bbd8d939 size 498507 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp index e741d50f4c..09b401a003 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0427b7729ce3cfa652a4595d04f936a947febec8f2c96ce33eed7cbaaa05613e +oid sha256:0eef025f8e8581868b02bcea37ff225afebcbb2966450fb29fb0e32ac54eccd4 size 668214 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp index eee064e280..0c6a45eacc 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:321bcd81b8965c8dfc08682f775508ae18e3ff711490ee8dff5fe56c20f74843 +oid sha256:abb2857ffb85cc36aae90ebb674635dffee2b2c5f7ad1ea81bb8002b65d5a0f8 size 711628 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp index 33f4d9cab3..9ecb64bd23 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aa77d3789c0ca314689125ec303a8af76554120a708a4b63395c69b7aad07f04 +oid sha256:49a3661535314b139e2794fe16f6f3e0a8d45742b68ea59ba99a9113068adf2c size 752698 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp index 3138343090..d836cccd03 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aa35aa70d0fa304c776c076a1a189d32a054d3f696dac5d99018085d1108c73b +oid sha256:d76fb6c4f8bb2de687bc5f9f275389356934119c1f0db9983dcf0ec7b68c6197 size 748726 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp index ca7815f710..79e1e96e9b 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d1a702d456b5acf279487dd810e3e33efdd1c7bd82530ceb5a32ad30ec30396c +oid sha256:be8ee89f4489c430d0ff6e9c6cf4e07379ac05abf468d47e34e084ad594b2037 size 946060 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp index 8bb9403c51..3c8b2528fc 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:558aa7d42de329c49361c94c4baef16738304b21b6adbe675d77c7819ef37660 +oid sha256:aa4be8ca2dd52e56c9a6af76b90ac353d217fad5fa931b21129ac5a811b5283a size 489823 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp index 0754f76695..22fce024ea 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7b5baa6048e6c33e74c6d343eb7c76252ff2e534fe467b3189af12b5d64af37c +oid sha256:cb0482b768a40bc7f8a86fa23a84bab62fb82c205f3237ff60becda50cbafc90 size 489823 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp index 68de134acb..c02b557e7f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e17cb191ad092e6db255ea503e49ea883ed56322fc58ed8d68710f6687376c1f +oid sha256:95b1796f4e7c905eca82ed3691427025f68e765797440b962b0114a5ab32b1d7 size 500083 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp index 3ebcc110ec..cbc081aae2 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bfca5660a931e08941347f7a0aefa82c214940e8eaa6b6d89cfded621f34a490 +oid sha256:2d9f13977fc865e716f1f35dfdb222a38000b224ff7394134230ed5c88119947 size 496125 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp index c0c882331e..cc613cc08d 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fffd2cd799953808034d7e7b89a57d4fede24db124bfb0d3938188177acbdfeb +oid sha256:007e32a06fcac853159dc5786940447281c57ba70406d38beb6f089fd037053d size 182023 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp index 458aa250b4..d8ba524113 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:19ada3a5d449542f103077db8d193bc2293a8f48ccee201e366473964287314c +oid sha256:26241ea5909395116e1b1a0f19cadc448886f6a6ab2b3ba76c092b67cd0148f0 size 182023 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp index 65edc3e52a..0206f71981 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b9c32124cd708aab7da30637d85437da0af9bf2157d163c19c6fe14498698cda +oid sha256:86e4ca60a459117c5e701631fbd3c67ca66e81d177c394c1fc9ad3b66396e69a size 661096 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp index 8213475b06..3444d759b7 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7f248fd42759509c61d20f912ae74dc3a85448a9c8386370ea92492ed9031e80 +oid sha256:770db1f4ec1c2d3c25767593b60cb095e49f7a6eb7abe054bbdec6e72db97f8d size 672936 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp index 75bd11ff6e..b99affa020 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:190fd946ddc7e1b5e9ca2172ec1de39c6288829773d9ce29fe98374256eff566 +oid sha256:0b6428cae2d0c8c813925be9589c94771098cfe5a6d0ff2036104d3e36384b81 size 721900 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp index ed5e241d9e..e93db30f53 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b7cd5976c836bcd75c0cadfe968050ac60bf89b93df021ad6c1681e159c497c5 +oid sha256:36c6932301fe3dc29631c28fcb8cb6b08652103bc7a36fd74a03a8189a1c77e4 size 717928 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp index 44ce0c307f..8f42d5a276 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7c536d725e1d9ebd2cb836dfe3993edcc81101534db6b7f1943c8a9443838bf4 +oid sha256:d858f6dcaf3f49fb3fa18b1c8c20ee1b933e2c8ddd1a429c8d3b5b4d269fb875 size 927892 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp index 0216db308c..0cb2a13410 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b5907da5a2f68c010d44bbbd0d780e097f9625be15b2f85e8dd1f00dd4c31ff9 +oid sha256:7dc92ab65ed0fc5f9d821f52a396a6d55ea9ae37e080eac7ff9e9c14eae741e7 size 631890 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp index c63b37264a..648e3acb00 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9cf14c71134a89ed6ffc83c0b7db06ed10e22b55294dc15ddf7f016427f01033 +oid sha256:d66606a37cfe8eb78ccc3f548a231f770df9f46e70f6d3ba22fb8abe6216480e size 159919 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp index 7d1ac80867..6028cc1f32 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f2b83c70dbc8ab0b3695dab3f4d2069b7ee7119e9140d7860b8c19f59a498589 +oid sha256:b723b296cff04602f64a5da9928e6f9b6a03c5cc608ba9ef7d8055f23f1f4ea2 size 159919 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp index 4041bfc97a..b1ee67b880 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fc8369f5701dceea91d429a713ddcbb4ecb0ad08d3c9042688557ead5f00e9da +oid sha256:d40578a5684262cd8136705367e2c98493ea9b9fcfc123c7efa3ead14017b5b8 size 483493 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp index f0afe3fcf1..4ce3d2dba5 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e9fffff2d13d49613e5f9334a010ca9bcde43b3bb55a792fd97fe2c867760dc +oid sha256:60cc82b9d11c53392de91a7c4c097263c20a56f9b346278c7c9af12ef2bb5fbf size 496123 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp index 03a4b33cef..d24465ed9c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dd3041ba5a52263f7f02d64f1911c50e346151bf529e865c1abf22583abd3e21 +oid sha256:8f685b6b2a0a573953f31fad89fa37e949361db245de69c0c06ce0bbb14eacef size 443285 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp index 6984f3c170..dc49a30627 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:12482099b086249163085e6e3421a61f6e304f865aaf56dd15382614be5e48e7 +oid sha256:834f0f3601c589893a21b957be2864df594f96b34b2cfd6018ada8319986aa21 size 441683 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp index 2bb4cc2582..4763a29923 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bfea1ea1627eaef7b614db08bad00bda8b611c8e466c858e050c0ce2aee2eafb +oid sha256:3d81a070e7ed49f1e1a322d38a757a3505186cf5cbded99814e950e07229a46a size 298049 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp index 7e76c5e13d..c8587a81d3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f828600699faa3a0474085cbbe88d2e0ac7c8e056c976b81a882c3a72682e527 +oid sha256:b9de5bc49d888699da1880d24ccf6a9cb6c0049d7a244d1ae9ab64b7365ecd5a size 296445 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp index 1c1f7bdc42..7d299b8705 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2d4b297922065ecb79b4a1278d048b253b57601d011fc5833a32f9fc1b78e58e +oid sha256:e30ed0df4b0d0b1da1ace5831dc0a7a526e04001b25860f862345c78acff5a43 size 427485 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp index 68394c07c1..47eeb69632 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3fd5305445c9856fbd5d9dfaffdd7f87b9014638f33fb63fb2cb4fce9893b20b +oid sha256:030015dc1811e3dc2ae36ed770f51063a3f46deae42ead5e1523c977b438a133 size 425883 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp index 51778ad0e9..1a5b22eed8 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2b7fee97097f799830df2bcb1c782c7ea9018243cbd5cd0e0f47ec299b49db79 +oid sha256:6921a204892e1336cef2a308be38855f3c888e56bd6a16752d2806aa9e93c431 size 1524634 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index 537871847d..834fa7d1c0 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8ac2f9270988bc02329ce11ef3413395b2b8cdc55fcf4911d170536c6e618317 -size 403697 +oid sha256:200df98fb2fcc734e8fc012c98c5d78c2061e5718eef6ffd50c2358a3d664197 +size 406065 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 6bf814ac8a..e085961e98 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1234cf31a3a6b84ed25fa0ad6c4df9b53f673f6bac2f639a66086ba50f8717ba -size 1120818 +oid sha256:430194fe07e526ad01a1e0fb43273b240c269215b132c9af248ba386dcbda23e +size 1124766 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 3bebbebcf1..2d56be2925 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0fff300932a16d30844e317ace515a178f159c483e436f6955983b96c5c424c6 -size 1549402 +oid sha256:53a07904a7bfbf82380c96af99c5e24bc86f77906c5d6fdc85ef9720639d76d2 +size 1569136 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index ef64a37682..6d074921cd 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ed10767ec913d314936fc5dbd1fd70c5381a622bf3fcf1590f837da6d3285bca -size 723774 +oid sha256:1ce4d27b11fee3e5f6489510b55613177e174660b6c7a6fb4efed862b62c50d7 +size 731668 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index d0bc52f131..a626899316 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7e7a7a9653a9c4e4e9b0514fc1d70abbb4521c7edbede52568d17d0779d62ffb -size 671662 +oid sha256:3992d7bd34e72089c5cffc4fc6de3f70a3995145b989811f83b00b47c96b5159 +size 681924 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 3056a533d6..d95d392d53 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1e18db0cd4de65e76e30f219d24ec00095fb16005882c43322182c5fa3f59032 -size 445541 +oid sha256:521417177fc0447809c07ff86b58725fedbf1a6b9412ace4c50268a20bc2680d +size 447119 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp index 50d7f1bece..c405f483ae 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9aceb502c1a95f58f1eab515cf2aeac92be6d255ef405008a4fd871fd54e9ba6 +oid sha256:cb063c946558e6928faabb85df9775fecd2b9444b40b3e06cf0f863db80a5ad8 size 1242842 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1a74df1288..e88a310b64 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ec96248452f638bb9ca50d3630dd67caf71322c01b17aff301c4a98eb7e27974 -size 1215548 +oid sha256:31e6b7442b277f5206cc1d70fa6021f36170265b311106281e88b4611d1a5b6b +size 1220284 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index e03f7c2575..0db1249a28 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dabc44860e81532e9b7ecb35773d0ad409d45361e20c9510d24387039999a7c3 -size 1720698 +oid sha256:c1342769efa91794d5bd35ac623b3014738b075b2671441668e2f0d5c1eef78a +size 1739642 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index b1d87c1278..4d68087ca1 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0d9c8d1fe282f46c12898ed4851a2640cb33ba5d75c5fe9da8a988f818a0e733 -size 407639 +oid sha256:a49dd8abcca57a64eb2ab4e00e4e0d26edf68488fb67086a4b466f8e6651522e +size 410007 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp index 2a12ddb711..deb498b1a2 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:849a280994b3fa1f18ca6c3866a16a68a9b02831f134f8dfcf0d34502c1d6772 +oid sha256:a7013b1eea12719ebeaf47facc37ef730bb0d6af03ca2ad890724a25448616a9 size 1102672 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index a2c78e856d..4bf37280a0 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e209b01409585433406f8392c77a7398270ee1b58446b728cf74faa6fe1bf9a +oid sha256:a16aeaf5d11a4c25461452b5f3145136b31861ef9c443d7ec82066565275d6f8 size 629884 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 61bbc8d762..0115c2c36f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0a22bb0202916831eced0a44acbab769d5647937155e0a2b5e6d0d0cb83c726f -size 1122394 +oid sha256:a7d4526887fe860e0d9c482fc7fe2cfe646c7a20bc8a0813ce33a01fd9cc733c +size 1125550 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index e0170f8db7..5d1d220755 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:582d17d48c7a751a345f74cc8c74f9b8c05278ddfc185da4906310a4973a9bdb -size 1547030 +oid sha256:b880e78ffc354edb541bd612e543dd894843fc4163f7bd65ce53282892381b8a +size 1566764 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp index 456d75f72f..fbab68022c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:70f02b7329eef7ceeb73dd43c3bf8f6ea6132c593bba6dbbed720d8b8ff0c287 +oid sha256:de26acaa532f197e339b6d5b2a2dd8032d505c9e169fce38000b02b2a4188eff size 603809 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index 0c0712acaf..8315c08084 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f67d4e70c39bf379ed0f3ef73a3690ac64efaee1e7134c793a760924c270f046 +oid sha256:cef5bcfe63650bc924d9e45d2755b50940534999fb4fbad3a8abf0ba73b9245a size 329935 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index f35d06ef06..c57602da24 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c2c284c6cb66207bd204bd1b6abe45aa8bf2e0c92631681861df237b8f849a46 -size 363451 +oid sha256:b332d4c6047c98b504cd3be72cc5028d240621c8e0a3260d64c17804982104db +size 365029 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 73d9547cf2..a0fe210d9b 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d3bede327d80be420e7bf011ee1a4156365afff7020bbf5a8434da18cb19fb23 -size 1093202 +oid sha256:a16c23767a2e5efbd7330728ed87af2ec62a7731debe1da557705c6db6d3268e +size 1096360 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 998e46d1f1..3c10c48136 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5ee7695bd5bb0a03eafe29a497060d84caec96ca4d159e99e4f02b99977dd2a6 -size 1469690 +oid sha256:66950bc137b734d509f0574152bcf9cf7efcb17a7483450d5fdbf480e9f83001 +size 1486266 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index a76bf3814f..0b4847611f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:cecca7ad5c652989a3008c8219177811ab9c7d617adbbc9ed8548141803c66f5 -size 694578 +oid sha256:bba586d9fe487c49cef2abfbfb0a078dde907d28e04b4d2335018cdb7031879c +size 701682 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index 71a5743dd9..fb1751942e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bd6847c0e897eb794a9b1ff67e64358527fe64c3e01fc214545cf76ec60edc6d -size 644046 +oid sha256:d3e45ab30e471f4649807f5b7640512e2c6678cf623cadfcb26c93eb4ad60ec0 +size 654306 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index ea50fb0631..ca8b31a010 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:118cc6d4a5e3e12ce0f2727361fd1d52d1a49c67d0bd1837c24e528c064a0dd7 -size 415557 +oid sha256:1932937b7f4ad0370341c77a03db133dd676bdf844b13eb45ec10243d1dfd16b +size 417135 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 285c32ec70..85d85fa4d9 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:36d6c97af5fb15f32cd1ff13f53dd98a7d670cb80ee766765f42cc453f730812 -size 1195826 +oid sha256:c11f5d464b0486023b78babfdfe9d2768e4b0d13caeb436d6f73110ede72498c +size 1198982 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index bd266daa63..465fcafece 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7775bbc1b43487236cf7570d2ed900f1c9830eab70aac1fa9dc59c439cc0c687 -size 1657562 +oid sha256:3bac9b40302bbfc6ee5a49e5c45d3238f46cff45619acd1b098d90e758d3ce30 +size 1675716 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 2d3c2887be..c65fa93d24 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:199b1ff3cc3d0ff04477ff8f1e6390dd62b3a7c9dd264cc73ce6c716af20a0f9 -size 366603 +oid sha256:26f09ab86b52c40b283652e555f677850f00902151d17e375e016b9a99a97794 +size 368183 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp index e0073c3730..36bdbdda6b 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2e743b470f9607abcbc8b71e7ef67455e6104daf3a80d0bd012a96ecf90a8f18 +oid sha256:960c3f9e4fe46fc6390207ba0ed85ec25435045e2213b60c5d44ea9ab4fa56aa size 1128730 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index 1553e77aee..58a89a84a2 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:366aa4e9f3263f73c4e76c0ea8008c0449b6d89bcade761500af949912786e32 +oid sha256:ac167d89ea3150f7b65614645ef09f13e2543bdc0523c1eddce5bbd9cfd306ee size 644892 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index cd0531dde0..cd64d2fe38 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5b8a8d76e17a24afd7af1dc5e112828f98ace78e3f85a7efaadb0cf1937085cc -size 1093198 +oid sha256:9d0cf59a8114940070448d87d02d9e83d53bb371ca9915c3983e03626d17024e +size 1097144 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index 54fd20f69c..f3194ad186 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aeffa2db467fbae3ace85fae9f31e2b8a7c0923ab349ade42318ae6f55249ac8 -size 1462582 +oid sha256:ff1449b6795f5beda0b6a62e8a1171ce952b07c4e63b607c06f5fedddb2debe9 +size 1480736 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp index 673041f7af..87c5afddec 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ffc92513e64631c33290f1e88e5666f5b85251506d527745c493f2e90da39de4 +oid sha256:cb14ae0271f8a83216f67c111530d3fe1be2231541ded5f992ff45226ae90e69 size 678808 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index c39e7fa450..dad37ebd42 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:faad8cb1e44f5e16f61720966d2a6c9e782461c209cd8000263b50d42093444d +oid sha256:46a0d8e0a9495e03f72526b4ee04fa3d2a2d87984057b44550cabf4ffa745ef4 size 370201 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index e2ee736b49..0000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dd930ed415b0303a973a37550ee33fa4975ad6be0cc58d461370b127f9a90f8e -size 1020542 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 95d9b2bf64..0000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4f2b243127e1ce00a850a10cca104ffc42512711f434fbdf8683eeeb49b8ce42 -size 1056062 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 0c093db643..0000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2ce9cc89b1db7f7e4b76b94cf1c3b04db49a2d86b529b1fc85b19057a99bc9fa -size 1007924 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index c24e239dd0..0000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e176513fa0074d688620299dfca53adc3902491e97ea9b6938a4ceb2fcf17ef5 -size 1068702 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp index 18ceeae41b..e769630511 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp @@ -137,8 +137,9 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams) // Are the input sequences padded ? mKernelParams.is_s_padded = mFixedParams.isSPadded; + // [total_q, h, 2] (max/sum) mKernelParams.softmax_stats_ptr = runnerParams.softmaxStatsPtr; - mKernelParams.softmax_stats_stride_in_bytes = sizeof(float) * mFixedParams.numQHeads; + mKernelParams.softmax_stats_stride_in_bytes = sizeof(float) * 2 * mFixedParams.numQHeads; if (mFixedParams.attentionInputLayout == AttentionInputLayout::PACKED_QKV) { @@ -238,6 +239,9 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams) mKernelParams.packed_mask_ptr = runnerParams.packedMaskPtr; mKernelParams.cu_mask_rows = reinterpret_cast<int const*>(runnerParams.cuMaskRowsPtr); } + TLLM_CHECK_WITH_INFO( + runnerParams.attentionSinksPtr == nullptr || mSM == kSM_90, "The attention sinks is only supported on SM90."); + mKernelParams.attention_sinks_ptr = runnerParams.attentionSinksPtr; mKernelParams.cu_q_seqlens = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr); mKernelParams.tile_id_counter_ptr = reinterpret_cast<uint32_t*>(runnerParams.tileCounterPtr); // TRT doesn't support host scales. Use device scales instead. @@ -294,6 +298,11 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams) = mFixedParams.isSPadded ? runnerParams.b * runnerParams.qSeqLen : runnerParams.totalQSeqLen; mLaunchParams.total_kv_seqlen = mFixedParams.isSPadded ? runnerParams.b * runnerParams.kvSeqLen : runnerParams.totalKvSeqLen; + // Workaround for nvbug 5412456: total_kv_seqlen fallbacks to total_q_seqlen if it's zero. + if (mLaunchParams.total_kv_seqlen == 0) + { + mLaunchParams.total_kv_seqlen = mLaunchParams.total_q_seqlen; + } TLLM_CHECK_WITH_INFO(mFixedParams.headSize > 0, "Head size should be greater than 0."); // Pad head size to next power of 2. @@ -453,9 +462,15 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams) } else { + bool isHopperBF16ContextMLA = (mFixedParams.headSize == mFixedParams.headSizeV + 64) && isSm90 + && mFixedParams.dataType == DATA_TYPE_BF16 && mFixedParams.headSizeV == 128; + // TODO: add support for separate QKV input layout mLaunchParams.supportReturnSoftmaxStats = (runnerParams.softmaxStatsPtr != nullptr && mLaunchParams.flash_attention && mLaunchParams.warp_specialization - && mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV); + && ((!isHopperBF16ContextMLA + && mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV) + || (isHopperBF16ContextMLA + && (mLaunchParams.attention_input_layout == AttentionInputLayout::Q_PAGED_KV)))); } } diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h index 96435cca52..e909886616 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h @@ -263,6 +263,8 @@ struct MHARunnerParams void* outputSfPtr; // The softmax_status ptr for RingAttention. void* softmaxStatsPtr; + // The attention sinks ptr. + float const* attentionSinksPtr; // The packed mask ptr. void const* packedMaskPtr; // The cumulative Q sequence lengths. @@ -352,6 +354,8 @@ struct Fused_multihead_attention_params_v2 KVBlockArrayForContextFMHA paged_kv_cache; // The mask to implement drop-out. void const* packed_mask_ptr; + // The attention sinks. + float const* attention_sinks_ptr; // The O matrix (output). void* o_ptr; // The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h index 934679a944..6758558e27 100644 --- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h +++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h @@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t ONESHOT = 4, TWOSHOT = 5, LOWPRECISION = 6, + MNNVL = 7, + NCCL_SYMMETRIC = 8, }; enum class AllReduceStrategyConfig : int8_t diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt index 861e94174a..3c3574c2d5 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt @@ -232,6 +232,11 @@ if(USING_OSS_CUTLASS_MOE_GEMM) set(MOE_GEMM_SRC_CU_LAUNCHER ${MOE_GEMM_SRC_CU}) list(FILTER MOE_GEMM_SRC_CU_LAUNCHER EXCLUDE REGEX ".*moe_gemm_kernels_.*") list(FILTER MOE_GEMM_SRC_CU INCLUDE REGEX ".*moe_gemm_kernels_.*") + set(MOE_GEMM_SRC_CU_HOPPER_FP4 ${MOE_GEMM_SRC_CU}) + list(FILTER MOE_GEMM_SRC_CU_HOPPER_FP4 INCLUDE REGEX + ".*moe_gemm_kernels_(bf16|fp16)_fp4.*") + list(FILTER MOE_GEMM_SRC_CU EXCLUDE REGEX + ".*moe_gemm_kernels_(bf16|fp16)_fp4.*") set(MOE_GEMM_SRC_CU_FP4 ${MOE_GEMM_SRC_CU}) list(FILTER MOE_GEMM_SRC_CU_FP4 INCLUDE REGEX ".*fp4.*") list(FILTER MOE_GEMM_SRC_CU EXCLUDE REGEX ".*fp4.*") @@ -244,6 +249,10 @@ if(USING_OSS_CUTLASS_MOE_GEMM) add_library(_moe_gemm_launcher OBJECT ${MOE_GEMM_SRC_CU_LAUNCHER}) add_cuda_architectures(_moe_gemm_launcher 89) + add_library(_moe_gemm_hopper_fp4 OBJECT ${MOE_GEMM_SRC_CU_HOPPER_FP4}) + set_cuda_architectures(_moe_gemm_hopper_fp4 90) + process_target(_moe_gemm_hopper_fp4 true false) + add_library(_moe_gemm_fp4 OBJECT ${MOE_GEMM_SRC_CU_FP4}) set_cuda_architectures(_moe_gemm_fp4 100f 103 120f) process_target(_moe_gemm_fp4 false true) @@ -253,8 +262,9 @@ if(USING_OSS_CUTLASS_MOE_GEMM) process_target(_moe_gemm_fp8 true true) add_instantiations(moe_gemm_src ${INSTANTIATION_GENERATION_DIR}/gemm_grouped) - target_link_libraries(moe_gemm_src PRIVATE _moe_gemm_launcher _moe_gemm_fp4 - _moe_gemm_fp8) + target_link_libraries( + moe_gemm_src PRIVATE _moe_gemm_launcher _moe_gemm_hopper_fp4 _moe_gemm_fp4 + _moe_gemm_fp8) target_include_directories( moe_gemm_src PUBLIC ${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 29e1528f1e..696c5def17 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -398,13 +398,12 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100( MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); - // TODO These need a specific epilogue sub tile (128, 64), not EpilogueTileAuto, otherwise they crash + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); - candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, - MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, @@ -442,7 +441,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100( if (config & CutlassGemmConfig::FP8_ONLY) { tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1}); - // TODO(sklevtsov): re-enable when handled by the MoE GEMM dispatch + // TODO: re-enable when handled by the MoE GEMM dispatch // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 }); } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h index e6c3a6bbfa..646be2575c 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h @@ -27,6 +27,7 @@ enum class ActivationType Silu, Swiglu, Geglu, + SwigluBias, Identity, InvalidType }; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index e89373c457..ba755ca669 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -37,11 +37,6 @@ namespace tensorrt_llm::kernels::cutlass_kernels { -template <class T> -constexpr auto transpose_stride(T const& t) -{ - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t)); -} template <typename AType, typename BType, typename BScaleType, typename OType> struct GroupedGemmInput @@ -72,8 +67,6 @@ struct GroupedGemmInput struct TmaWarpSpecializedGroupedGemmInput { - template <class T> - using TransposeStride = decltype(transpose_stride<T>(T{})); template <class Tag> using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; @@ -86,6 +79,7 @@ struct TmaWarpSpecializedGroupedGemmInput using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand + using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand constexpr static int NVFP4BlockScaleVectorSize = 16; constexpr static int MXFPXBlockScaleVectorSize = 32; @@ -121,6 +115,7 @@ struct TmaWarpSpecializedGroupedGemmInput using StrideB = std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>; + using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>; #ifdef ENABLE_FP8 template <class T> @@ -147,37 +142,26 @@ struct TmaWarpSpecializedGroupedGemmInput StrideC* stride_c = nullptr; void const** ptr_c = nullptr; - struct DefaultEpilogue - { - using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>; - - StrideD* stride_d = nullptr; - void** ptr_d = nullptr; - }; + // D is used in all cases except fused finalize + StrideD* stride_d = nullptr; + void** ptr_d = nullptr; struct FusedFinalizeEpilogue { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>; - using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>; + using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>; void* ptr_final_output = nullptr; StrideFinalOutput stride_final_output{}; - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; + void const** ptr_bias = nullptr; + float const** ptr_router_scales = nullptr; - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; + int const** ptr_source_token_index = nullptr; + int num_rows_in_final_output = 0; - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; - - size_t num_rows_in_final_output = 0; + bool use_reduction = true; }; - DefaultEpilogue default_epilogue; FusedFinalizeEpilogue fused_finalize_epilogue; enum class EpilogueFusion @@ -210,8 +194,10 @@ struct TmaWarpSpecializedGroupedGemmInput struct INT4GroupwiseParams { - constexpr static int group_size = 128; // Unused, hard-coded to 128 + constexpr static int int4_group_size = 128; + constexpr static int wfp4a16_group_size = 32; bool enabled = false; + bool use_wfp4a16 = false; using SFA = __nv_bfloat16; using SFB = __nv_bfloat16; // Unused using ProblemShapeInt = cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>; @@ -233,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput uint8_t* gemm_workspace = nullptr; size_t gemm_workspace_size = 0; - static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); + static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type); @@ -245,16 +231,15 @@ struct TmaWarpSpecializedGroupedGemmInput return stride_a != nullptr && ptr_a != nullptr; } - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); + void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction); std::string toString() const; }; constexpr bool isGatedActivation(ActivationType activation_type) { - return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu + || activation_type == ActivationType::SwigluBias; } template <typename T, /*The type used for activations/scales/compute*/ @@ -267,6 +252,12 @@ class MoeGemmRunner public: MoeGemmRunner(); +#if defined(ENABLE_BF16) + static constexpr bool use_wfp4a16 + = std::is_same_v<WeightType, __nv_fp4_e2m1> && (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>); +#else + static constexpr bool use_wfp4a16 = std::is_same_v<WeightType, __nv_fp4_e2m1> && std::is_same_v<T, half>; +#endif #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v<T, __nv_fp8_e4m3> @@ -281,6 +272,7 @@ public: static constexpr bool use_fp8 = false; static constexpr bool use_w4afp8 = false; #endif + static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) static constexpr bool use_fp4 = std::is_same_v<T, __nv_fp4_e2m1>; @@ -305,9 +297,9 @@ public: [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; [[nodiscard]] bool supportsTmaWarpSpecialized() const; - [[nodiscard]] bool isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; - [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; + [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, + ActivationType activation_type, int gemm_n, int gemm_k) const; + [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; size_t getMaxWorkspaceSize(int num_experts) const; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index c7c9a55b95..7d592bed0e 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -87,6 +87,62 @@ struct LoraParams namespace cutlass_kernels { +static inline size_t pad_to_multiple_of_16(size_t const& input) +{ + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter +{ +public: + CubKeyValueSorter(); + + CubKeyValueSorter(int const num_experts_per_node); + + void updateNumExperts(int const num_experts_per_node); + + static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node); + + void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in, + int* values_out, size_t const num_key_value_pairs, cudaStream_t stream); + +private: + static int expertsToBits(int experts); + int num_experts_; + int num_bits_; +}; + +struct ActivationParams +{ + ActivationType activation_type; + float const* swiglu_alpha = nullptr; + float const* swiglu_beta = nullptr; + float const* swiglu_limit = nullptr; + + explicit ActivationParams(ActivationType activation_type) + : activation_type(activation_type) + { + TLLM_CHECK_WITH_INFO(activation_type != ActivationType::SwigluBias, + "SwigluBias is not supported in ActivationParams without swiglu_alpha and swiglu_beta"); + } + + ActivationParams( + ActivationType activation_type, float const* swiglu_alpha, float const* swiglu_beta, float const* swiglu_limit) + : activation_type(activation_type) + , swiglu_alpha(swiglu_alpha) + , swiglu_beta(swiglu_beta) + , swiglu_limit(swiglu_limit) + { + } + + // TODO Port everything properly and get rid of these implicit conversions + operator ActivationType() const + { + return activation_type; + } +}; + /** * \brief Describes what parallelism mode the MoE is using * @@ -392,14 +448,14 @@ public: = 0; virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() = 0; - virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) + virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, + void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) = 0; // Aliases for profiling the gemms @@ -410,7 +466,7 @@ public: float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) @@ -439,7 +495,8 @@ public: void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) = 0; virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> @@ -456,13 +513,13 @@ public: virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; - bool use_deterministic_hopper_reduce_ = false; + bool use_fused_finalize_ = true; }; // Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . // Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. // Avoid making several duplicates of this class. -template <typename T, /*The type used for activations*/ +template <typename T, /* The type used for activations */ typename WeightType, /* The type for the MoE weights */ typename OutputType = T, /* The type for the MoE final output */ typename InputType = T, /* The type for the MoE input */ @@ -474,6 +531,13 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface = tensorrt_llm::kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunnerInterface; using ScaleBiasType = BackBoneType; using Self = CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType>; + +#if defined(ENABLE_BF16) + static constexpr bool use_wfp4a16 + = std::is_same_v<WeightType, __nv_fp4_e2m1> && (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>); +#else + static constexpr bool use_wfp4a16 = std::is_same_v<WeightType, __nv_fp4_e2m1> && std::is_same_v<T, half>; +#endif #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>) &&!std::is_same_v<WeightType, cutlass::uint4b_t>; @@ -485,6 +549,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface static constexpr bool use_fp8 = false; static constexpr bool use_w4afp8 = false; #endif + static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) static constexpr bool act_fp4 = std::is_same_v<T, __nv_fp4_e2m1>; static constexpr bool weight_fp4 = std::is_same_v<WeightType, __nv_fp4_e2m1>; @@ -539,14 +604,14 @@ public: return RunnerType::getConfigs(sm); } - void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; + void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, + void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner, @@ -563,7 +628,7 @@ public: TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids); @@ -591,7 +656,7 @@ public: float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) override @@ -645,7 +710,8 @@ public: void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override { return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, @@ -654,7 +720,8 @@ public: alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params, reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2), reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output), - reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), stream); + reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row, + stream); } std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> @@ -679,7 +746,7 @@ public: private: std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> setupTmaWarpSpecializedInputs( - int64_t num_rows, int64_t expanded_num_rows, ActivationType fc1_activation_type, int64_t hidden_size, + int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, int64_t inter_size, int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -696,7 +763,8 @@ private: float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); + UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row, + cudaStream_t stream); static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, @@ -726,8 +794,8 @@ private: bool mayHaveFinalizeFused() const { - return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90 - && !use_deterministic_hopper_reduce_ && !use_w4afp8; + return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_ + && !use_w4_groupwise; } // TODO: This should eventually take the quant params to give more flexibility @@ -758,7 +826,7 @@ private: WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, - ActivationType fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); + ActivationParams fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h index b1676993de..0b86afda68 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h @@ -58,7 +58,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream); + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + void const* prequant_scales, cudaStream_t stream); template <class OutputType, class GemmOutputType, class ScaleBiasType> void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index efcb4b0e5d..e96132773b 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,21 +26,14 @@ #include "cute/tensor.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/fusion/operations.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -203,33 +196,25 @@ using SafeBF16 = void; TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, \ cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) \ { \ - constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \ - /* constexpr static bool BIAS = BIAS_; */ /* Always false */ \ using ArchTag = cutlass::arch::ArchTag_; \ + constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \ + constexpr static bool IsMXFPX = MXFPX_; \ + constexpr static bool IsBlackwell = ArchTag::kMinComputeCapability >= 100; \ + constexpr static bool IsSM120 \ + = ArchTag::kMinComputeCapability == 120 || ArchTag::kMinComputeCapability == 121; \ + constexpr static bool IsSM103 = ArchTag::kMinComputeCapability == 103; \ + constexpr static bool Is2SM = IsBlackwell && (CGA_M_ % 2 == 0); \ + /* constexpr static bool BIAS = BIAS_; */ /* Always false */ \ using T = DataType_; \ using WeightType = WeightType_; \ using OutputType = OutputType_; \ - constexpr static bool IsMXFPX = MXFPX_; \ - using Arch = ArchTag; \ - constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \ - constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \ - constexpr static bool IsSM103 = ArchTag::kMinComputeCapability == 103; \ - constexpr static bool IsWFP4AFP8 \ - = cutlass::platform::is_same<WeightType, SafeFP4>::value && cutlass::platform::is_same<T, SafeFP8>::value; \ constexpr static bool IsFP4 = cutlass::platform::is_same<T, SafeFP4>::value; \ - static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by SM100"); \ - \ - constexpr static bool IsFP8 = cutlass::platform::is_same<T, SafeFP8>::value; \ - \ constexpr static bool IsSM103FP4 = IsSM103 && IsFP4; \ - static_assert(IsSM103 == IsSM103FP4, "SM103 only implemented for fp4"); \ - \ - constexpr static bool Is2SM = IsBlackwell && (CGA_M_ % 2 == 0); \ + /* static_assert(IsSM103 == IsSM103FP4, "SM103 only implemented for fp4"); */ \ using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \ using MmaTileShape = cute::Shape<cute::Int<CTA_M_*(Is2SM ? 2 : 1)>, cute::Int<CTA_N_>, \ cute::Int<CTA_K_*(IsSM103FP4 ? 3 : 1)>>; \ using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>; \ - \ if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && ArchTag::kMinComputeCapability >= 90 \ && ArchTag::kMinComputeCapability < 100) \ { \ @@ -256,6 +241,11 @@ using SafeBF16 = void; // typename MmaTileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \ // struct TmaWarpSpecializedGroupedGemmInfo \ { */ \ + constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same<WeightType, SafeFP4>::value \ + && cutlass::platform::is_same<T, SafeFP8>::value; \ + static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by Blackwell"); \ + \ + constexpr static bool IsFP8 = cutlass::platform::is_same<T, SafeFP8>::value; \ \ /* TODO Update once mixed input support is added */ \ static_assert(cutlass::platform::is_same<T, WeightType>::value || IsWFP4AFP8, \ @@ -329,8 +319,8 @@ using SafeBF16 = void; // units of elements (up to 16 bytes)*/ \ \ /* D matrix configuration */ \ - using LayoutD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::LayoutD; \ - using StrideD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD; \ + using LayoutD = TmaWarpSpecializedGroupedGemmInput::LayoutD; \ + using StrideD = TmaWarpSpecializedGroupedGemmInput::StrideD; \ constexpr static int AlignmentD \ = 128 / cutlass::sizeof_bits<ElementD>::value; /* Memory access granularity/alignment of D matrix \ // in units of elements (up to 16 bytes) */ \ @@ -348,7 +338,7 @@ using SafeBF16 = void; // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \ // cutlass::epilogue::?????????????????? /// <<<<<< what supports activations \ // >;*/ \ - using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; \ + using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \ \ constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \ using EpilogueScheduleSM10x = std::conditional_t<Is2SM, cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm, \ @@ -361,12 +351,12 @@ using SafeBF16 = void; using EpilogueElementC = std::conditional_t<IsSM120, ElementCSafe, ElementC>; \ using EpilogueTensorOp = std::conditional_t<IsBlackwell && IsBlockScaled, \ cutlass::arch::OpClassBlockScaledTensorOp, cutlass::arch::OpClassTensorOp>; \ - using EpilogueSubTile \ - = std::conditional_t<Arch::kMinComputeCapability == 100 && IsFP4 && CTA_N_ == 256, /* SM100 Exactly */ \ - cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \ + using EpilogueSubTile = std::conditional_t<ArchTag::kMinComputeCapability == 100 && IsFP4 \ + && CTA_N_ == 256, /* SM100 Exactly */ \ + cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \ /* Epilogue For Default Finalize */ \ using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder</**/ \ - Arch, EpilogueTensorOp, /**/ \ + ArchTag, EpilogueTensorOp, /**/ \ MmaTileShape, ClusterShape, /**/ \ EpilogueSubTile, /**/ \ ElementAccumulator, ElementAccumulator, /**/ \ @@ -375,18 +365,17 @@ using SafeBF16 = void; EpilogueSchedule>::CollectiveOp; \ \ /* Epilogue For Fused Finalize */ \ - using CollectiveEpilogueFinalize = \ - typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< /**/ \ - Arch, MmaTileShape, /**/ \ - ElementCSafe, StrideC*, /**/ \ - ElementFinalOutput, \ - TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, /**/ \ - ElementAccumulator, /**/ \ - ElementAccumulator, /**/ \ - ElementBias, TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, /**/ \ - ElementRouterScales, \ - TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales /**/ \ - >::CollectiveOp; \ + using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder</**/ \ + ArchTag, EpilogueTensorOp, /**/ \ + MmaTileShape, ClusterShape, /**/ \ + EpilogueSubTile, /**/ \ + ElementAccumulator, ElementAccumulator, /**/ \ + EpilogueElementC, LayoutC*, AlignmentC, /**/ \ + void, LayoutD*, AlignmentD, /**/ \ + EpilogueSchedule, /**/ \ + cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter< /**/ \ + LayoutD, ElementFinalOutput, ElementAccumulator, ElementBias, ElementRouterScales> /**/ \ + >::CollectiveOp; \ \ using CollectiveEpilogue = std::conditional_t<FUSION == EpilogueFusion::FINALIZE, \ CollectiveEpilogueFinalize, CollectiveEpilogueDefault>; \ @@ -429,7 +418,7 @@ using SafeBF16 = void; using MainloopElementB = std::conditional_t<IsBlackwell && IsBlockScaled, ElementBBlockScaled, ElementB>; \ \ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder</**/ \ - Arch, TensorOp, /**/ \ + ArchTag, TensorOp, /**/ \ MainloopElementB, LayoutB*, AlignmentB, /* A & B swapped here */ \ MainloopElementA, LayoutA*, AlignmentA, /**/ \ ElementAccumulator, /**/ \ @@ -441,7 +430,7 @@ using SafeBF16 = void; \ using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; \ /*}; \ - \ \ + // \ // using namespace cute; \ // using GemmInfo = TmaWarpSpecializedGroupedGemmInfo;<ArchTag, T, WeightType, OutputType, \ EpilogueTag, \ @@ -497,7 +486,7 @@ using SafeBF16 = void; TLLM_CHECK(tma_ws_input.ptr_a); \ TLLM_CHECK(tma_ws_input.ptr_b); \ \ - auto make_mainloop_params = [&]() -> MainloopArguments \ + MainloopArguments const mainloop_args = [&] \ { \ if constexpr (IsBlockScaled) \ { \ @@ -517,67 +506,46 @@ using SafeBF16 = void; reinterpret_cast<ElementB const**>(tma_ws_input.ptr_b), tma_ws_input.stride_b, \ reinterpret_cast<ElementA const**>(tma_ws_input.ptr_a), tma_ws_input.stride_a); \ } \ - }; \ - \ - auto const mainloop_params = make_mainloop_params(); \ - \ - using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ - using EpilogueScalars = decltype(EpilogueArguments{}.thread); \ - auto make_epilogue_scalars = [&]() \ + }(); \ + using FusionArguments = typename CollectiveEpilogue::FusionCallbacks::Arguments; \ + FusionArguments fusion_args = [&] \ { \ - if constexpr (IsBlackwell) \ + if constexpr (FUSION == EpilogueFusion::FINALIZE) \ { \ - return construct_if_true<IsBlackwell, EpilogueScalars>(ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, nullptr, \ - tma_ws_input.alpha_scale_ptr_array, nullptr, \ - cute::Shape<_0, _0, int64_t>{ \ - cute::_0{}, cute::_0{}, (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \ - cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \ - } \ - else if (tma_ws_input.alpha_scale_ptr_array) \ - { \ - return construct_if_true<!IsBlackwell, EpilogueScalars>(tma_ws_input.alpha_scale_ptr_array); \ + auto epi_params = tma_ws_input.fused_finalize_epilogue; \ + return construct_if_true<FUSION == EpilogueFusion::FINALIZE, FusionArguments>( \ + ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \ + reinterpret_cast<ElementBias const* const*>(epi_params.ptr_bias), \ + Stride<_1, _0, int64_t>{}, /* bias */ \ + epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */ \ + reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output), \ + epi_params.stride_final_output, epi_params.ptr_source_token_index, \ + epi_params.num_rows_in_final_output, epi_params.use_reduction); \ } \ else \ { \ - return construct_if_true<!IsBlackwell, EpilogueScalars>(ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); \ + return construct_if_true<FUSION != EpilogueFusion::FINALIZE, FusionArguments>( \ + ElementAccumulator(1), ElementAccumulator(0), nullptr, nullptr, \ + tma_ws_input.alpha_scale_ptr_array, nullptr, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, Stride<_0, _0, int64_t>{}); \ } \ - }; \ - auto epilogue_scalars = make_epilogue_scalars(); \ - /* TODO ptr_c casts to ElementCSafe** because there is a workaround in CUTLASS */ \ - auto make_epi_args = [&]() \ - { \ - static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ - "Unimplemented fusion provided to TMA WS MoE gemm launcher"); \ + }(); \ \ - if constexpr (FUSION == EpilogueFusion::NONE) \ + using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ + EpilogueArguments epilogue_args = [&] \ + { \ + if constexpr (FUSION == EpilogueFusion::FINALIZE) \ { \ - auto epi_params = tma_ws_input.default_epilogue; \ - return construct_if_true<FUSION == EpilogueFusion::NONE, EpilogueArguments>(epilogue_scalars, \ - nullptr, tma_ws_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), \ - epi_params.stride_d); \ - } \ - else if constexpr (FUSION == EpilogueFusion::FINALIZE) \ - { \ - /* Parameters for fused finalize */ \ - auto epi_params = tma_ws_input.fused_finalize_epilogue; \ return construct_if_true<FUSION == EpilogueFusion::FINALIZE, EpilogueArguments>( \ - epilogue_scalars, /* Parameters to underlying epilogue */ \ - nullptr, tma_ws_input.stride_c, /* C params */ \ - reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output), \ - epi_params.stride_final_output, /* D (output) params */ \ - reinterpret_cast<ElementBias const*>(epi_params.ptr_bias), \ - epi_params.stride_bias, /* Bias params */ \ - epi_params.ptr_router_scales, epi_params.stride_router_scales, /* Router scales */ \ - epi_params.ptr_expert_first_token_offset, /* Offset of this expert's token in the \ - router scales */ \ - epi_params.ptr_source_token_index, /* Index of the source token to sum into */ \ - epi_params.num_rows_in_final_output /* Number of tokens in the output buffer */ \ - ); \ + fusion_args, nullptr, nullptr, nullptr, nullptr); \ } \ - }; \ - EpilogueArguments const epilogue_params = make_epi_args(); \ + else \ + { \ + return construct_if_true<FUSION != EpilogueFusion::FINALIZE, EpilogueArguments>(fusion_args, \ + nullptr, nullptr, reinterpret_cast<ElementD**>(tma_ws_input.ptr_d), tma_ws_input.stride_d); \ + } \ + }(); \ /* EpilogueArguments const epilogue_params = make_epi_args<EpilogueArguments, EpilogueScalars, \ ElementCSafe, ElementD, ElementFinalOutput, ElementBias, FUSION>( \ // tma_ws_input, epilogue_scalars \ @@ -587,7 +555,7 @@ using SafeBF16 = void; 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ \ const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ - tma_ws_input.shape_info, mainloop_params, epilogue_params, hw_info, scheduler_args}; \ + tma_ws_input.shape_info, mainloop_args, epilogue_args, hw_info, scheduler_args}; \ \ size_t calculated_ws_size = gemm.get_workspace_size(args); \ TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \ diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl index 764a53b107..9701048a89 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl @@ -86,15 +86,14 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, ///////////////////////////////////////////////////////////////////////////////////////////////// // A matrix configuration - // using ElementA = typename TllmToCutlassTypeAdapter<T>::type; - using ElementA = cutlass::float_e4m3_t; + using ElementA = typename TllmToCutlassTypeAdapter<T>::type; using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes) // B matrix configuration - // using ElementB = typename TllmToCutlassTypeAdapter<WeightType>::type; - using ElementB = typename cutlass::int4b_t; + using ElementB_ = typename TllmToCutlassTypeAdapter<WeightType>::type; + using ElementB = std::conditional_t<std::is_same_v<WeightType, cutlass::uint4b_t>, cutlass::int4b_t, ElementB_>; using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of @@ -109,9 +108,13 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>; // Scale configuration - constexpr int PackedScalesNum = get<2>(CTAShape{}) / 128; - using ElementScalePacked - = cutlass::Array<TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA, PackedScalesNum>; + constexpr bool use_wfp4a16 = std::is_same_v<ElementB, cutlass::float_e2m1_t>; + constexpr int group_size = use_wfp4a16 ? cutlass::gemm::collective::detail::mxfp4_group_size + : cutlass::gemm::collective::detail::int4_group_size; + constexpr int PackedScalesNum = get<2>(CTAShape{}) / group_size; + using ElementScale = std::conditional_t<use_wfp4a16, cutlass::float_ue8m0_t, + TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA>; + using ElementScalePacked = cutlass::Array<ElementScale, PackedScalesNum>; using LayoutScale = cutlass::layout::RowMajor; // C/D matrix configuration @@ -171,20 +174,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, Args arguments; decltype(arguments.epilogue.thread) fusion_args; - fusion_args.alpha = 0; + fusion_args.alpha = use_wfp4a16 ? 1 : 0; fusion_args.beta = 0; fusion_args.alpha_ptr = nullptr; fusion_args.beta_ptr = nullptr; - fusion_args.alpha_ptr_array = inputs.alpha_scales; + fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales; fusion_args.beta_ptr_array = nullptr; // One alpha and beta per each group - fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; - fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1}; cutlass::KernelHardwareInfo hw_info; hw_info.device_id = 0; hw_info.sm_count = sm_count_; + assert(group_size == int(inputs.groupwise_quant_group_size)); if (workspace_size != nullptr) { const Args args{cutlass::gemm::GemmUniversalMode::kGrouped, @@ -192,10 +196,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, {reinterpret_cast<ElementB const**>(hopper_inputs.ptr_b), hopper_inputs.stride_b, reinterpret_cast<ElementA const**>(hopper_inputs.ptr_a), hopper_inputs.stride_a, reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, int(inputs.groupwise_quant_group_size)}, + hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, {fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c), hopper_inputs.stride_c, - reinterpret_cast<ElementD**>(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, + reinterpret_cast<ElementD**>(hopper_inputs.ptr_d), hopper_inputs.stride_d}, hw_info}; *workspace_size = gemm.get_workspace_size(args); return; @@ -206,10 +209,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, {reinterpret_cast<ElementB const**>(hopper_inputs.ptr_b), hopper_inputs.stride_b, reinterpret_cast<ElementA const**>(hopper_inputs.ptr_a), hopper_inputs.stride_a, reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, int(inputs.groupwise_quant_group_size)}, + hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, {fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c), hopper_inputs.stride_c, - reinterpret_cast<ElementD**>(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, + reinterpret_cast<ElementD**>(hopper_inputs.ptr_d), hopper_inputs.stride_d}, hw_info}; if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu new file mode 100644 index 0000000000..be29019bc6 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_template_dispatch.h" + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, __nv_fp4_e2m1, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu new file mode 100644 index 0000000000..f1a885ea77 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_template_dispatch.h" + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template class MoeGemmRunner<half, __nv_fp4_e2m1, half>; +} diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 941b3cf1cf..5a07062f06 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -75,6 +75,7 @@ namespace tensorrt_llm::kernels::cutlass_kernels { +using tensorrt_llm::kernels::cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass; // ============================= Variable batched Gemm things =========================== template <typename T, typename WeightType, typename GemmOutputType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, @@ -99,6 +100,7 @@ struct genericMoeGemmKernelLauncher static_assert(cutlass::platform::is_same<T, WeightType>::value || cutlass::platform::is_same<WeightType, uint8_t>::value + || cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value || cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value); static_assert(arch::kMinComputeCapability < 90, "Sm90+ architecture should use specialized kernels"); @@ -506,8 +508,8 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>( weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - if (!tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>() - || (use_w4afp8 && sm != 89)) + if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>() || (use_w4afp8 && sm != 89) + || use_wfp4a16) { return {}; } @@ -596,18 +598,19 @@ int MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getSM() const // currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType> bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsFusedGatedActivation( - bool is_gated_activation, int gemm_n, int gemm_k) const + ActivationType activation_type, int gemm_n, int gemm_k) const { constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; - return is_gated_activation && std::is_same_v<T, WeightType> && !std::is_same_v<T, float> && !use_fp8 - && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; + return (activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu) + && std::is_same_v<T, WeightType> && !std::is_same_v<T, float> && !use_fp8 && (this->getSM() >= 80) + && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; } template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType> bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const + cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const { - return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_tma_warp_specialized; + return supportsFusedGatedActivation(activation_type, gemm_n, gemm_k) && !gemm_config.is_tma_warp_specialized; } template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType> @@ -639,26 +642,41 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch( if (sm_ >= 75 && sm_ < 80) { - dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>( - inputs, multi_processor_count_); - } - else if (sm_ >= 80 && sm_ < 90) - { - if constexpr (use_fp8 || use_w4afp8) + if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) { -#if defined(ENABLE_FP8) - static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>, - "FP8 GEMM Output not supported"); -#endif - - TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); - dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>( + dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>( inputs, multi_processor_count_); } else { - dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>( - inputs, multi_processor_count_); + TLLM_THROW("FP4 data type is not supported on SM < 90"); + } + } + else if (sm_ >= 80 && sm_ < 90) + { + + if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) + { + if constexpr (use_fp8 || use_w4afp8) + { +#if defined(ENABLE_FP8) + static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>, + "FP8 GEMM Output not supported"); +#endif + + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>( + inputs, multi_processor_count_); + } + else + { + dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>( + inputs, multi_processor_count_); + } + } + else + { + TLLM_THROW("FP4 data type is not supported on SM < 90"); } } else if (sm_ >= 90) @@ -674,9 +692,8 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch( } } - if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType, - EpilogueTag>() - && !use_w4afp8) + if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType, EpilogueTag>() + && !use_w4_groupwise) { // We allow both tma warp specialized and SM80 configurations to coexist because for some cases with small // numbers of tokens SM80 is faster. We check here to see which is selected @@ -719,33 +736,39 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch( // Hopper finegrained INT4 WS grouped GEMM if constexpr (use_w4afp8) { - if (inputs.gemm_config.is_tma_warp_specialized) + TLLM_CHECK_WITH_INFO( + inputs.gemm_config.is_tma_warp_specialized, "w4afp8 is only supported for TMA warp specialization"); + // EpilogueTag is ignored + if (inputs.k % 512 == 0) { - // EpilogueTag is ignored - if (inputs.k % 512 == 0) - { - cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, - cutlass_extensions::EpilogueOpDefault, 4>( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else if (inputs.k % 256 == 0) - { - cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, - cutlass_extensions::EpilogueOpDefault, 2>( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else if (inputs.k % 128 == 0) - { - cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, - cutlass_extensions::EpilogueOpDefault, 1>( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else - { - TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k); - } - return; - }; + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, + cutlass_extensions::EpilogueOpDefault, 4>(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else if (inputs.k % 256 == 0) + { + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, + cutlass_extensions::EpilogueOpDefault, 2>(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else if (inputs.k % 128 == 0) + { + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, + cutlass_extensions::EpilogueOpDefault, 1>(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else + { + TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k); + } + return; + } + + if constexpr (use_wfp4a16) + { + TLLM_CHECK_WITH_INFO( + inputs.gemm_config.is_tma_warp_specialized, "wfp4a16 is only supported for TMA warp specialization"); + // EpilogueTag is ignored + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, + cutlass_extensions::EpilogueOpDefault, 1>(inputs, hopper_inputs, multi_processor_count_, nullptr); + return; } #endif @@ -798,7 +821,7 @@ size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceS template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType> size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspaceSize(int num_experts) const { - if constexpr (use_w4afp8) + if constexpr (use_w4_groupwise) { return cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput<T, WeightType, OutputType>( num_experts, multi_processor_count_); @@ -807,8 +830,8 @@ size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspace { return 0; } - if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType>() - && !use_w4afp8) + if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType>() && !use_w4afp8 + && !use_wfp4a16) { auto configs = getTmaWarpSpecializedConfigs(sm_); auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index bf15aed55f..168c50a8a2 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -406,7 +406,6 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG // SHAPE_CASE(100, 128, 128, 64) // SHAPE_CASE(100, 128, 256, 64) - // SHAPE_CASE(100, 256, 256, 64) DEFAULT_CASE(100) } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h index 4c0ddebf6a..1ee7232c9e 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h @@ -155,10 +155,13 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best // for mixed type gemms. - constexpr int Ktile = 128 * PackedScalesNum / sizeof(T); - TLLM_CHECK(sizeof(T) == 1); + constexpr int Ntile = (std::is_same_v<WeightType, __nv_fp4_e2m1>) ? 64 : 128; + constexpr int Ktile = (std::is_same_v<WeightType, __nv_fp4_e2m1>) ? 128 : 128 * PackedScalesNum / sizeof(T); + TLLM_CHECK(sizeof(T) == (std::is_same_v<WeightType, __nv_fp4_e2m1>) ? 2 : 1); + using _Ntile = Int<Ntile>; using _Ktile = Int<Ktile>; + switch (inputs.gemm_config.tile_config_sm90) { case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: @@ -174,8 +177,8 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( inputs, hopper_inputs, sm_count_, workspace_size); break; case tkc::CutlassTileConfigSM90::CtaShape64x128x128B: - sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _128, _Ktile>>( - inputs, hopper_inputs, sm_count_, workspace_size); + sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, + Shape<_64, _Ntile, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break; // case tkc::CutlassTileConfigSM90::CtaShape64x256x128B: // sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _256, @@ -226,11 +229,14 @@ template <typename T, typename WeightType, typename OutputType> size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_count_) { size_t count = 0; + constexpr int Ktile = (std::is_same_v<WeightType, __nv_fp4_e2m1>) ? 256 : 512; + using _Ktile = Int<Ktile>; + #ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS GroupedGemmInput<T, WeightType, OutputType, OutputType> inputs{}; inputs.num_experts = num_experts; sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, OutputType, - tensorrt_llm::cutlass_extensions::EpilogueOpDefault, Shape<_128, _64, _512>, Shape<_1, _1, _1>, + tensorrt_llm::cutlass_extensions::EpilogueOpDefault, Shape<_128, _64, _Ktile>, Shape<_1, _1, _1>, cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>( inputs, TmaWarpSpecializedGroupedGemmInput{}, sm_count_, &count); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu index 485c19496f..b49dfec999 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu @@ -27,14 +27,14 @@ namespace tensorrt_llm::kernels::cutlass_kernels { -std::array<size_t, 17> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( +std::array<size_t, 20> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( int num_experts, FpXBlockScalingType scaling_type) { size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; size_t stride_a_size = sizeof(StrideA) * num_experts; size_t stride_b_size = sizeof(StrideB) * num_experts; size_t stride_c_size = sizeof(StrideC) * num_experts; - size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; + size_t stride_d_size = sizeof(StrideD) * num_experts; size_t ptr_buf_size = sizeof(void*) * num_experts; size_t scale_buf_size = sizeof(float*) * num_experts; @@ -53,9 +53,12 @@ std::array<size_t, 17> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( size_t int4_groupwise_sf_a_size = sizeof(INT4GroupwiseParams::SFA*) * num_experts; size_t int4_groupwise_stride_sf_a_size = sizeof(INT4GroupwiseParams::StrideSFA) * num_experts; + size_t ptr_token_map_size = sizeof(int**) * num_experts; + return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size, ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size, sf_a_size, sf_b_size, stride_sf_a_size, - stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size}; + stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size, + ptr_buf_size, scale_buf_size, ptr_token_map_size}; } size_t TmaWarpSpecializedGroupedGemmInput::workspaceSize(int num_experts, FpXBlockScalingType scaling_type) @@ -68,7 +71,7 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i size_t gemm_workspace_size, FpXBlockScalingType scaling_type) { auto buffers = workspaceBuffers(num_experts, scaling_type); - std::array<int8_t*, 17> pointers{}; + std::array<int8_t*, 20> pointers{}; TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); for (int i = 0; i < buffers.size(); i++) { @@ -82,12 +85,12 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i stride_a = reinterpret_cast<StrideA*>(pointers[1]); stride_b = reinterpret_cast<StrideB*>(pointers[2]); stride_c = reinterpret_cast<StrideC*>(pointers[3]); - default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(pointers[4]); + stride_d = reinterpret_cast<StrideD*>(pointers[4]); ptr_a = reinterpret_cast<void const**>(pointers[5]); ptr_b = reinterpret_cast<void const**>(pointers[6]); ptr_c = reinterpret_cast<void const**>(pointers[7]); - default_epilogue.ptr_d = reinterpret_cast<void**>(pointers[8]); + ptr_d = reinterpret_cast<void**>(pointers[8]); alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]); @@ -103,28 +106,24 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i int4_groupwise_params.ptr_s_a = reinterpret_cast<INT4GroupwiseParams::SFA const**>(pointers[15]); int4_groupwise_params.stride_s_a = reinterpret_cast<INT4GroupwiseParams::StrideSFA*>(pointers[16]); + fused_finalize_epilogue.ptr_bias = reinterpret_cast<void const**>(pointers[17]); + fused_finalize_epilogue.ptr_router_scales = reinterpret_cast<float const**>(pointers[18]); + fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast<int const**>(pointers[19]); + this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace); this->gemm_workspace_size = gemm_workspace_size; } -void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens) +void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams( + void* final_output, int hidden_size, int num_output_tokens, bool use_reduction) { fused_finalize_epilogue.ptr_final_output = final_output; - fused_finalize_epilogue.ptr_router_scales = router_scales; - fused_finalize_epilogue.ptr_bias = bias; - fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; - fused_finalize_epilogue.ptr_source_token_index = source_token_index; - fused_finalize_epilogue.stride_final_output - = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, - transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); - fused_finalize_epilogue.stride_bias - = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); - fused_finalize_epilogue.stride_router_scales = {}; + fused_finalize_epilogue.stride_final_output = cutlass::make_cute_packed_stride( + FusedFinalizeEpilogue::StrideFinalOutput{}, cute::make_shape(hidden_size, num_output_tokens, 1)); fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; + fused_finalize_epilogue.use_reduction = use_reduction; } std::string TmaWarpSpecializedGroupedGemmInput::toString() const @@ -143,16 +142,13 @@ std::string TmaWarpSpecializedGroupedGemmInput::toString() const ss << "Final Output: " << (PrintType) fused_finalize_epilogue.ptr_final_output; ss << " with Stride: " << fused_finalize_epilogue.stride_final_output; ss << ",\nBias: " << (PrintType) fused_finalize_epilogue.ptr_bias; - ss << " with Stride: " << fused_finalize_epilogue.stride_bias; ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nExpert Offset: " << (PrintType) fused_finalize_epilogue.ptr_expert_first_token_offset; ss << ", Source Map: " << (PrintType) fused_finalize_epilogue.ptr_source_token_index; } else { - ss << "Ptr D: " << (PrintType) default_epilogue.ptr_d; - ss << " with Stride: " << (PrintType) default_epilogue.stride_d; + ss << "Ptr D: " << (PrintType) ptr_d; + ss << " with Stride: " << (PrintType) stride_d; } ss << '\n'; ss << "Alpha scale ptr: " << (PrintType) alpha_scale_ptr_array << "\n"; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 0caf687b56..730840717c 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -997,12 +997,12 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type) { constexpr bool is_fp8 = std::is_same_v<QuantizedType, __nv_fp8_e4m3>; - static constexpr int NumThreadsPerSF = VecSize / CVT_FP4_ELTS_PER_THREAD; + static constexpr int NumThreadsPerSF = VecSize / CVT_ELTS_PER_THREAD; // Quantize the input to FP4 static_assert(std::is_same_v<GemmOutputType, __nv_bfloat16> || std::is_same_v<GemmOutputType, half>); - static_assert(ComputeElem::kElements == CVT_FP4_ELTS_PER_THREAD); + static_assert(ComputeElem::kElements == CVT_ELTS_PER_THREAD); PackedVec<GemmOutputType> packed_vec{}; - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { packed_vec.elts[i].x = static_cast<GemmOutputType>(post_act_val[i * 2 + 0]); packed_vec.elts[i].y = static_cast<GemmOutputType>(post_act_val[i * 2 + 1]); @@ -1013,10 +1013,9 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s = act_sf_flat + getOffsetActivationSF(expert_id, num_tokens_before_expert, num_cols, scaling_type); // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert - auto sf_out - = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF, VecSize>( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, - num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF>( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, act_sf_expert, QuantizationSFLayout::SWIZZLED); // Do the conversion and set the output and scaling factor auto func = [&]() @@ -1043,7 +1042,7 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s template <int VecSize, int ElementsPerThread> __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id, int64_t elem_idx, int64_t num_cols, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf = true) { static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; @@ -1055,20 +1054,31 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert - auto sf_out - = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF, VecSize>( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, - num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF>( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, act_sf_expert, QuantizationSFLayout::SWIZZLED); if (sf_out) { if (input_sf) { - auto const sf_in - = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF, - VecSize>(std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf), - FP4QuantizationSFLayout::SWIZZLED); - *sf_out = *sf_in; + if (swizzled_input_sf) + { + auto const sf_in + = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF>( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf), + QuantizationSFLayout::SWIZZLED); + *sf_out = *sf_in; + } + else + { + auto const sf_in + = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF>( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf), + QuantizationSFLayout::LINEAR); + *sf_out = *sf_in; + } } else { @@ -1155,14 +1165,19 @@ __device__ void computeTmaWarpSpecializedInputStrides( } if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info.default_epilogue.stride_d[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1)); + layout_info.stride_d[out_idx] = cutlass::make_cute_packed_stride( + TmaWarpSpecializedGroupedGemmInput::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1)); } if (layout_info.int4_groupwise_params.enabled) { layout_info.int4_groupwise_params.stride_s_a[out_idx] = cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::StrideSFA{}, - cute::make_shape(gemm_n, gemm_k / 128, 1)); + cute::make_shape(gemm_n, + gemm_k + / (layout_info.int4_groupwise_params.use_wfp4a16 + ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size), + 1)); } } @@ -1170,7 +1185,8 @@ template <class T, class WeightType, class OutputType, class ScaleBiasType> __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGroupedGemmInput& layout_info, int64_t gemm_m, int64_t gemm_n, int64_t gemm_k, int num_tokens_before_expert, int64_t expert, T const* in, WeightType const* weights, TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale, - ScaleBiasType const* bias, OutputType* output, int64_t const out_idx) + ScaleBiasType const* bias, OutputType* output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, int64_t const out_idx) { // The input prior to this contains K elements per token, with `num_tokens_before_expert` tokens layout_info.ptr_a[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k); @@ -1181,12 +1197,28 @@ __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGrouped if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { // The output prior to this contains N elements per token, with `num_tokens_before_expert` tokens - layout_info.default_epilogue.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + layout_info.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + } + if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) + { + + layout_info.fused_finalize_epilogue.ptr_source_token_index[expert] + = permuted_row_to_unpermuted_row + num_tokens_before_expert; + layout_info.fused_finalize_epilogue.ptr_router_scales[expert] = router_scales + num_tokens_before_expert; + if (layout_info.fused_finalize_epilogue.ptr_bias != nullptr) + { + layout_info.fused_finalize_epilogue.ptr_bias[expert] = bias + gemm_n * expert; + } } if (layout_info.int4_groupwise_params.enabled) { - layout_info.int4_groupwise_params.ptr_s_a[out_idx] - = safe_inc_ptr(w4a8_weight_scale, expert * (gemm_n * gemm_k / 128)); + // The group size of wfp4a16 is multiplied by 2 because each scale uses 1 byte instead of 2 bytes + layout_info.int4_groupwise_params.ptr_s_a[out_idx] = safe_inc_ptr(w4a8_weight_scale, + expert + * (gemm_n * gemm_k + / (layout_info.int4_groupwise_params.use_wfp4a16 + ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size * 2 + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size))); } } @@ -1199,7 +1231,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir WeightType const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output) + ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output, + float const* router_scales, int const* permuted_row_to_unpermuted_row) { // First, compute the global tid. We only need 1 thread per expert. int const expert = blockIdx.x * blockDim.x + threadIdx.x; @@ -1277,12 +1310,12 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir gemm1_in, weights1, reinterpret_cast<TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const*>( quant_params.groupwise.fc1.weight_scales), - bias1, gemm1_output, expert); + bias1, gemm1_output, nullptr, nullptr, expert); computeTmaWarpSpecializedInputPointers(layout_info2, gemm_m, gemm2_n, gemm2_k, num_tokens_before_expert, expert, gemm2_in, weights2, reinterpret_cast<TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const*>( quant_params.groupwise.fc2.weight_scales), - bias2, gemm2_output, expert); + bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert); } template <class T, class WeightType, class OutputType, class ScaleBiasType> @@ -1400,12 +1433,12 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k)); assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); - layout_info1.default_epilogue.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n); + layout_info1.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n); if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { // The output prior to this contains N elements per token, with `num_tokens` tokens - layout_info2.default_epilogue.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n); + layout_info2.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n); } } else @@ -1415,10 +1448,10 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali layout_info1.ptr_b[expert] = nullptr; layout_info2.ptr_b[expert] = nullptr; - layout_info1.default_epilogue.ptr_d[expert] = nullptr; + layout_info1.ptr_d[expert] = nullptr; if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info2.default_epilogue.ptr_d[expert] = nullptr; + layout_info2.ptr_d[expert] = nullptr; } } } @@ -1452,8 +1485,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size, int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node, - InputActivationsType const* prequant_scales = nullptr) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) { static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ, "AWQ and Block Scaling are mutually exclusive"); @@ -1487,7 +1520,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp : TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize; constexpr int64_t ELEM_PER_THREAD - = (is_nvfp4 || is_mxfp8) ? CVT_FP4_ELTS_PER_THREAD : (128 / sizeof_bits<InputActivationsType>::value); + = (is_nvfp4 || is_mxfp8) ? CVT_ELTS_PER_THREAD : (128 / sizeof_bits<InputActivationsType>::value); // This should be VecSize * 4 elements // We assume at least VecSize alignment or the quantization will fail @@ -1555,7 +1588,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); writeSF<VecSize, ELEM_PER_THREAD>(num_tokens_before_expert, expert, source_row, permuted_row, - elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); + elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf, swizzled_input_sf); dest_row_ptr[elem_index] = in_vec; } } @@ -1656,7 +1689,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + void const* prequant_scales, cudaStream_t stream) { #ifdef ENABLE_FP4 TLLM_CHECK_WITH_INFO( @@ -1732,8 +1766,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, config.attrs = attrs; cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales, permuted_row_to_unpermuted_row, num_rows, hidden_size, k, quant_params.fp4.fc1.act_global_scale, - use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node, - reinterpret_cast<InputActivationsType const*>(prequant_scales)); + use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, swizzled_input_sf, + num_experts_per_node, reinterpret_cast<InputActivationsType const*>(prequant_scales)); } #define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \ @@ -1743,8 +1777,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, \ QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, \ TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \ - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, \ - cudaStream_t stream) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, \ + void const* prequant_scales, cudaStream_t stream) // Instantiate the data types that are used by the external pytorch op INSTANTIATE_EXPAND_INPUT_ROWS(float, float); @@ -1994,8 +2028,8 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro #define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \ template void finalizeMoeRoutingKernelLauncher<OutputT, GemmOutputT, ScaleBiasT>( \ GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, ScaleBiasT const* bias, \ - float const* final_scales, int const* expanded_source_row_to_expanded_dest_row, \ - int const* expanded_dest_row_to_expanded_source_row, int const* expert_for_source_row, \ + float const* final_scales, int const* unpermuted_row_to_permuted_row, \ + int const* permuted_row_to_unpermuted_row, int const* expert_for_source_row, \ int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, \ int64_t const experts_per_token, int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, \ bool const enable_alltoall, cudaStream_t stream); @@ -2007,16 +2041,67 @@ INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float); INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16); #endif +// ============================== Activation Adaptors ================================= +template <template <class> class ActFn> +struct IdentityAdaptor +{ + constexpr static bool IS_GLU = false; + float alpha = 1.0f; + float beta = 0.0f; + float limit = std::numeric_limits<float>::infinity(); + + template <class T> + __device__ T operator()(T const& x) const + { + ActFn<T> fn{}; + return fn(x); + } +}; + +template <template <class> class ActFn> +struct GLUAdaptor +{ + constexpr static bool IS_GLU = true; + float alpha = 1.0f; + float beta = 0.0f; + float limit = std::numeric_limits<float>::infinity(); + + template <class T> + __device__ T operator()(T const& gate, T const& linear) const + { + ActFn<T> fn{}; + return fn(gate) * linear; + } +}; + +struct SwigluBiasAdaptor +{ + constexpr static bool IS_GLU = true; + float alpha = 1.0f; + float beta = 0.0f; + float limit = std::numeric_limits<float>::infinity(); + + template <class T> + __device__ T operator()(T const& gate, T const& linear) const + { + cutlass::epilogue::thread::Sigmoid<T> fn{}; + T linear_clamped = cutlass::maximum<T>{}(cutlass::minimum<T>{}(linear, limit), -limit); + T gate_clamped = cutlass::minimum<T>{}(gate, limit); + return gate_clamped * fn(gate_clamped * alpha) * (linear_clamped + beta); + } +}; + // ============================== Gated Activation ================================= constexpr static int ACTIVATION_THREADS_PER_BLOCK = 256; -template <class ActivationOutputType, class GemmOutputType, template <class> class ActFn> +template <class ActivationOutputType, class GemmOutputType, class ActFn> __global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutputType const* gemm_result, - int64_t const* num_valid_tokens_ptr, int64_t inter_size) + int64_t const* expert_first_token_offset, int64_t inter_size, int64_t num_experts_per_node, + ActivationParams activation_type) { int64_t const tid = threadIdx.x; int64_t const token = blockIdx.x; - if (num_valid_tokens_ptr && token >= *num_valid_tokens_ptr) + if (token >= expert_first_token_offset[num_experts_per_node]) { return; } @@ -2037,39 +2122,61 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutput int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; int64_t const inter_size_vec = inter_size / ACTIVATION_ELEM_PER_THREAD; - ActFn<ComputeElem> fn{}; + float gate_alpha = 1.0f; + float gate_bias = 0.0f; + float gate_limit = std::numeric_limits<float>::infinity(); + if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) + { + int expert + = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t) token + 1) - 1; + gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f; + gate_bias = activation_type.swiglu_beta ? activation_type.swiglu_beta[expert] : 0.0f; + gate_limit = activation_type.swiglu_limit ? activation_type.swiglu_limit[expert] + : std::numeric_limits<float>::infinity(); + } + + ActFn fn{}; + fn.alpha = gate_alpha; + fn.beta = gate_bias; + fn.limit = gate_limit; for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { - auto fc1_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index]); + auto linear_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index]); // BF16 isn't supported, use FP32 for activation function auto gate_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + inter_size_vec]); - auto gate_act = fn(gate_value); - output_vec[elem_index] = arrayConvert<ComputeElem, OutputElem>(fc1_value * gate_act); + auto gate_act = fn(gate_value, linear_value); + output_vec[elem_index] = arrayConvert<ComputeElem, OutputElem>(gate_act); } } template <typename ActivationOutputType, typename GemmOutputType> void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_result, - int64_t const* num_valid_tokens_ptr, int64_t inter_size, int64_t num_tokens, ActivationType activation_type, - cudaStream_t stream) + int64_t const* expert_first_token_offset, int64_t inter_size, int64_t num_tokens, int64_t num_experts_per_node, + ActivationParams activation_type, cudaStream_t stream) { int64_t const blocks = num_tokens; int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; - auto* fn = activation_type == ActivationType::Swiglu - ? &doGatedActivationKernel<ActivationOutputType, GemmOutputType, cutlass::epilogue::thread::SiLu> - : &doGatedActivationKernel<ActivationOutputType, GemmOutputType, cutlass::epilogue::thread::GELU>; - fn<<<blocks, threads, 0, stream>>>(output, gemm_result, num_valid_tokens_ptr, inter_size); + auto* fn = (activation_type == ActivationType::Swiglu) + ? &doGatedActivationKernel<ActivationOutputType, GemmOutputType, GLUAdaptor<cutlass::epilogue::thread::SiLu>> + : activation_type == ActivationType::Geglu + ? &doGatedActivationKernel<ActivationOutputType, GemmOutputType, GLUAdaptor<cutlass::epilogue::thread::GELU>> + : activation_type == ActivationType::SwigluBias + ? &doGatedActivationKernel<ActivationOutputType, GemmOutputType, SwigluBiasAdaptor> + : nullptr; + TLLM_CHECK_WITH_INFO(fn != nullptr, "Invalid activation type"); + fn<<<blocks, threads, 0, stream>>>( + output, gemm_result, expert_first_token_offset, inter_size, num_experts_per_node, activation_type); } // ============================== Activation ================================= -template <class T, class GemmOutputType, class ScaleBiasType, template <class> class ActFn, +template <class T, class GemmOutputType, class ScaleBiasType, class ActFn, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType> __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, - int num_experts_per_node, int64_t inter_size, bool gated, float const* fc2_act_global_scale, - bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat) + int num_experts_per_node, int64_t inter_size, float const* fc2_act_global_scale, bool use_per_expert_act_scale, + TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, ActivationParams activation_params) { #ifdef ENABLE_FP4 constexpr bool IsNVFP4 = std::is_same_v<T, __nv_fp4_e2m1> @@ -2082,18 +2189,15 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, #endif int64_t const tid = threadIdx.x; - size_t const gated_size_mul = gated ? 2 : 1; - size_t const gated_off = gated ? inter_size : 0; - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif + constexpr bool IsGated = ActFn::IS_GLU; + size_t gated_size_mul = IsGated ? 2 : 1; + size_t gated_off = IsGated ? inter_size : 0; constexpr int64_t VecSize = IsNVFP4 ? TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize : TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize; // Load 128-bits per thread, according to the smallest data type we read/write constexpr int64_t ACTIVATION_ELEM_PER_THREAD = (IsNVFP4 || IsMXFP8) - ? CVT_FP4_ELTS_PER_THREAD + ? CVT_ELTS_PER_THREAD : (128 / std::min(sizeof_bits<T>::value, sizeof_bits<GemmOutputType>::value)); // This should be VecSize * 4 elements @@ -2104,16 +2208,28 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x) { size_t gemm_result_offset = token * inter_size * gated_size_mul; size_t output_offset = token * inter_size; int64_t expert = 0; - if (bias_ptr || IsNVFP4 || IsMXFP8 || use_per_expert_act_scale) + float gate_alpha = 1.0f; + float gate_beta = 0.0f; + float gate_limit = std::numeric_limits<float>::infinity(); + if (bias_ptr || IsNVFP4 || IsMXFP8 || use_per_expert_act_scale || activation_params.swiglu_alpha + || activation_params.swiglu_beta || activation_params.swiglu_limit) { // TODO this is almost certainly faster as a linear scan expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; + + gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f; + gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f; + gate_limit = activation_params.swiglu_limit ? activation_params.swiglu_limit[expert] + : std::numeric_limits<float>::infinity(); } size_t act_scale_idx = use_per_expert_act_scale ? expert : 0; @@ -2145,7 +2261,10 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0); int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD; - ActFn<ComputeElem> fn{}; + ActFn fn{}; + fn.alpha = gate_alpha; + fn.beta = gate_beta; + fn.limit = gate_limit; for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto fc1_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]); @@ -2154,17 +2273,22 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, fc1_value = fc1_value + arrayConvert<BiasElem, ComputeElem>(bias_ptr_vec[elem_index + gated_off_vec]); } - auto gate_act = fn(fc1_value); - - if (gated) + auto gate_act = [&]() { - auto gate_mul = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index]); - if (bias_ptr_vec) + if constexpr (IsGated) { - gate_mul = gate_mul + arrayConvert<BiasElem, ComputeElem>(bias_ptr_vec[elem_index]); + auto linear_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index]); + if (bias_ptr_vec) + { + linear_value = linear_value + arrayConvert<BiasElem, ComputeElem>(bias_ptr_vec[elem_index]); + } + return fn(fc1_value, linear_value); } - gate_act = gate_act * gate_mul; - } + else + { + return fn(fc1_value); + } + }(); auto post_act_val = gate_act * quant_scale; @@ -2252,7 +2376,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, template <class T, class GemmOutputType, class ScaleBiasType> void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias, bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size, - int64_t expanded_num_tokens, ActivationType activation_type, QuantParams const& quant_params, + int64_t expanded_num_tokens, ActivationParams activation_type, QuantParams const& quant_params, bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream) { @@ -2275,20 +2399,23 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 auto fn = [&](auto block_scaling_type) { auto fn_list = std::array{ - &doActivationKernel<T, GemmOutputType, ScaleBiasType, cutlass::epilogue::thread::GELU, + &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>, // Gelu - &doActivationKernel<T, GemmOutputType, ScaleBiasType, cutlass::epilogue::thread::ReLu, + &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value>, // Relu - &doActivationKernel<T, GemmOutputType, ScaleBiasType, cutlass::epilogue::thread::SiLu, + &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>, // Silu - &doActivationKernel<T, GemmOutputType, ScaleBiasType, cutlass::epilogue::thread::SiLu, + &doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>, // Swiglu - &doActivationKernel<T, GemmOutputType, ScaleBiasType, cutlass::epilogue::thread::GELU, + &doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>, // Geglu - &doActivationKernel<T, GemmOutputType, ScaleBiasType, cutlass::epilogue::thread::Identity, - decltype(block_scaling_type)::value> // Identity + &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor, + decltype(block_scaling_type)::value>, // SwigluBias + &doActivationKernel<T, GemmOutputType, ScaleBiasType, + IdentityAdaptor<cutlass::epilogue::thread::Identity>, + decltype(block_scaling_type)::value> // Identity }; - return fn_list[static_cast<int>(activation_type)]; + return fn_list[static_cast<int>(activation_type.activation_type)]; }; auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{}; @@ -2325,8 +2452,8 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, expert_first_token_offset, - num_experts_per_node, inter_size, isGatedActivation(activation_type), quant_params.fp4.fc2.act_global_scale, - use_per_expert_act_scale, fc2_act_sf_flat); + num_experts_per_node, inter_size, quant_params.fp4.fc2.act_global_scale, use_per_expert_act_scale, + fc2_act_sf_flat, activation_type); } // ============================== Lora Add Bias ================================= @@ -2726,7 +2853,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab bool const is_gated_activation = isGatedActivation(activation_type); bool const gemm1_using_fused_moe - = moe_gemm_runner_.isFusedGatedActivation(*gemm1_config_, is_gated_activation, inter_size, hidden_size); + = moe_gemm_runner_.isFusedGatedActivation(*gemm1_config_, activation_type, inter_size, hidden_size); bool const gemm1_using_tma_ws = moe_gemm_runner_.isTmaWarpSpecialized(*gemm1_config_); bool const tma_ws_has_glu = gemm1_using_tma_ws && (mayHaveDifferentGEMMOutputType() || is_gated_activation); // We always use fused path if we can @@ -2826,7 +2953,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena int64_t const* const expert_first_token_offset, WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, QuantParams& quant_params, cudaStream_t stream) + int const num_experts_per_node, ActivationParams fc1_activation_type, QuantParams& quant_params, + cudaStream_t stream) { bool const is_gated_activation = isGatedActivation(fc1_activation_type); @@ -2914,7 +3042,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) { @@ -2931,7 +3059,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab bool const using_tma_ws_gemm1 = gemm_runner.isTmaWarpSpecialized(config); bool const is_gated_activation = isGatedActivation(fc1_activation_type); bool const use_ampere_activation_fusion - = gemm_runner.isFusedGatedActivation(config, is_gated_activation, inter_size, hidden_size); + = gemm_runner.isFusedGatedActivation(config, fc1_activation_type.activation_type, inter_size, hidden_size); size_t const fc1_out_size = ((!use_ampere_activation_fusion) && is_gated_activation) ? inter_size * 2 : inter_size; int64_t const* total_tokens_including_expert = expert_first_token_offset + 1; @@ -3064,7 +3192,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab : nullptr, fc1_expert_biases, static_cast<OutputType*>(use_ampere_activation_fusion ? output : intermediate_result), alpha_scale_ptr_array, /*occupancy*/ nullptr, - use_ampere_activation_fusion ? fc1_activation_type : ActivationType::Identity, expanded_num_rows, + use_ampere_activation_fusion ? fc1_activation_type.activation_type : ActivationType::Identity, + expanded_num_rows, /*N*/ int64_t(fc1_out_size), /*K*/ hidden_size, num_experts_per_node, quant_params.groupwise.group_size, bias_is_broadcast, use_ampere_activation_fusion, stream, config}; @@ -3076,8 +3205,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab { using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>; doGatedActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output), - static_cast<UnfusedGemmOutputType const*>(intermediate_result), num_valid_tokens_ptr, inter_size, - expanded_num_rows, fc1_activation_type, stream); + static_cast<UnfusedGemmOutputType const*>(intermediate_result), expert_first_token_offset, inter_size, + expanded_num_rows, num_experts_per_node, fc1_activation_type, stream); sync_check_cuda_error(stream); } @@ -3174,14 +3303,13 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab loraBiasApplyFunc(static_cast<UnfusedGemmOutputType*>(gemm_output), static_cast<UnfusedGemmOutputType const*>(gemm_output), nullptr, static_cast<ScaleBiasType const*>(fc2_lora), false, expert_first_token_offset, num_experts_per_node, - hidden_size, expanded_num_rows, ActivationType::Identity, {}, false, nullptr, stream); + hidden_size, expanded_num_rows, ActivationParams(ActivationType::Identity), {}, false, nullptr, stream); sync_check_cuda_error(stream); } bool has_different_output_type_ampere = (use_w4afp8 || use_fp8) && !using_tma_ws_gemm2; - bool using_hopper_fused_finalize - = tma_ws_input.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - bool has_different_output_type_tma_ws = !using_hopper_fused_finalize && using_tma_ws_gemm2; + bool using_fused_finalize = tma_ws_input.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; + bool has_different_output_type_tma_ws = !using_fused_finalize && using_tma_ws_gemm2; if (has_different_output_type_ampere || has_different_output_type_tma_ws) { @@ -3406,17 +3534,17 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab template <class T, class WeightType, class OutputType, class InputType, class BackBoneType, class Enable> void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::runMoe( - void const* input_activations_void, void const* input_sf_void, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void, - ActivationType fc1_activation_type, void const* fc2_expert_weights_void, void const* fc2_expert_biases_void, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const full_num_experts, int const experts_per_token, char* workspace_ptr, void* final_output_void, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) + void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights_void, + void const* fc1_expert_biases_void, ActivationParams fc1_activation_type, void const* fc2_expert_weights_void, + void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const full_num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output_void, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) { static constexpr bool int_scales_required - = std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value; + = std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value || use_wfp4a16; static constexpr bool fp8_scales_required = std::is_same<WeightType, __nv_fp8_e4m3>::value || std::is_same<WeightType, __nv_fp8_e5m2>::value; @@ -3527,7 +3655,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab fc2_fp8_dequant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received quant scale for FC2"); } - bool use_awq = quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales; + bool use_awq = quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16; int const num_experts_per_node = full_num_experts / parallelism_config.ep_size; configureWsPtrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, @@ -3584,7 +3712,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab else { bool fused_prologue_result = false; - if (!use_w4afp8) + if (!use_w4_groupwise) { // WAR: fusedBuildExpertMapsSortFirstToken kernel will lead to illegal memory access for W4AFP8 fused_prologue_result = fusedBuildExpertMapsSortFirstToken(token_selected_experts, @@ -3625,7 +3753,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab expandInputRowsKernelLauncher(input_activations, gemm1_input_expand, token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token, num_experts_per_node, quant_params, use_per_expert_act_scale, expert_first_token_offset_, - fc1_fp4_act_scale_, input_sf, use_w4afp8 ? quant_params.groupwise.fc1.act_scales : nullptr, stream); + fc1_fp4_act_scale_, input_sf, swizzled_input_sf, + use_w4afp8 ? quant_params.groupwise.fc1.act_scales : nullptr, stream); auto const* gemm1_input = gemm1_input_expand; sync_check_cuda_error(stream); @@ -3698,7 +3827,8 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>:: float const* fp8_dequant2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, cudaStream_t stream) + UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row, + cudaStream_t stream) { // Always nullptr layout_info1.ptr_c = nullptr; @@ -3706,6 +3836,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>:: layout_info2.ptr_c = nullptr; layout_info2.stride_c = nullptr; + layout_info1.fused_finalize_epilogue.ptr_bias = nullptr; + if (!bias2) + { + layout_info2.fused_finalize_epilogue.ptr_bias = nullptr; + } + auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale : use_fp8 ? fp8_dequant1 @@ -3720,8 +3856,10 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>:: layout_info2.alpha_scale_ptr_array = nullptr; } - layout_info1.int4_groupwise_params.enabled = use_w4afp8; - layout_info2.int4_groupwise_params.enabled = use_w4afp8; + layout_info1.int4_groupwise_params.enabled = use_w4_groupwise; + layout_info2.int4_groupwise_params.enabled = use_w4_groupwise; + layout_info1.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; + layout_info2.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; layout_info1.fpX_block_scaling_type = getScalingType(); layout_info2.fpX_block_scaling_type = getScalingType(); @@ -3744,7 +3882,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>:: cudaLaunchKernelEx(&config, kernel_instance, expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, gemm1_in, gemm2_in, weights1, weights2, alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params, bias1, bias2, - gemm1_output, gemm2_output); + gemm1_output, gemm2_output, router_scales, permuted_row_to_unpermuted_row); return std::make_pair(layout_info1, layout_info2); } @@ -3762,7 +3900,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert, cudaStream_t stream) { - TLLM_CHECK_WITH_INFO(!use_w4afp8, "W4AFP8 is not supported in low latency mode"); + TLLM_CHECK_WITH_INFO(!use_w4_groupwise, "W4AFP8 and WFP4A16 are not supported in low latency mode"); // Always nullptr layout_info1.ptr_c = nullptr; @@ -3787,6 +3925,8 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, layout_info1.int4_groupwise_params.enabled = false; layout_info2.int4_groupwise_params.enabled = false; + layout_info1.int4_groupwise_params.use_wfp4a16 = false; + layout_info2.int4_groupwise_params.use_wfp4a16 = false; int const threads = std::min(1024, num_experts); int const blocks = (num_experts + threads - 1) / threads; @@ -3813,7 +3953,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, template <class T, class WeightType, class OutputType, class InputType, class BackBoneType, class Enable> std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::setupTmaWarpSpecializedInputs( - int64_t num_rows, int64_t expanded_num_rows, ActivationType fc1_activation_type, int64_t hidden_size, + int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, int64_t inter_size, int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -3829,7 +3969,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>:: return std::make_pair(gemm1_tma_ws_input, gemm2_tma_ws_input); } - bool use_awq = quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales; + bool use_awq = quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16; bool is_gated_activation = isGatedActivation(fc1_activation_type); int64_t const fc1_out_size = is_gated_activation ? inter_size * 2 : inter_size; @@ -3865,15 +4005,15 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>:: gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; bool apply_bias = parallelism_config.tp_rank == 0; - bool using_hopper_fused_finalize - = !use_deterministic_hopper_reduce_ && gemm2_config_->sm_version == 90 && !use_w4afp8 && !use_lora; - if (using_hopper_fused_finalize) + auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr; + bool using_fused_finalize + = use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora; + if (using_fused_finalize) { assert(min_latency_mode == false); + bool use_reduction = expanded_num_rows > num_rows; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - gemm2_tma_ws_input.setFinalizeFusionParams(final_output, permuted_token_final_scales_, - expert_first_token_offset_, permuted_row_to_unpermuted_row_, apply_bias ? fc2_expert_biases : nullptr, - hidden_size, num_rows); + gemm2_tma_ws_input.setFinalizeFusionParams(final_output, hidden_size, num_rows, use_reduction); } // fp8_mxfp4 memsets the scaling factors to 1.0f @@ -3907,9 +4047,10 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>:: gemm2_tma_ws_input, num_rows, expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size, num_experts_per_node, reinterpret_cast<T const*>(gemm1_input), reinterpret_cast<T const*>(gemm2_input), fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2, - fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_expert_biases, + fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_bias, reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output), - reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_), stream); + reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_), permuted_token_final_scales_, + permuted_row_to_unpermuted_row_, stream); } } @@ -4136,6 +4277,8 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile bool is_fp4_w_quant = mWType == nvinfer1::DataType::kFP4 || mWType == nvinfer1::DataType::kINT64; bool is_w4afp8_quant = is_int_groupwise_w_quant && is_fp8_act_quant; // bool is_wfp4afp8_quant = is_fp4_w_quant && is_fp8_act_quant; + bool is_wfp4a16_quant = (mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) + && mWType == nvinfer1::DataType::kUINT8; // Int sizes size_t quant_1_size = is_int_w_quant ? fc1_out_size * num_experts_per_node * dtype_bytes : 0; @@ -4145,7 +4288,7 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile quant_1_size = fc1_out_size * num_experts_per_node * dtype_bytes; quant_2_size = hidden_size * num_experts_per_node * dtype_bytes; } - else if (is_int_groupwise_w_quant) + else if (is_int_groupwise_w_quant || is_wfp4a16_quant) { quant_1_size = fc1_out_size * num_experts_per_node * dtype_bytes * hidden_size / mGroupSize; quant_2_size = hidden_size * num_experts_per_node * dtype_bytes * inter_size / mGroupSize; @@ -4182,7 +4325,7 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile = TmaWarpSpecializedGroupedGemmInput::workspaceSize(num_experts_per_node, mScalingType) * (NUM_ROUTING_SAMPLES + 1); - if (is_w4afp8_quant) + if (is_w4afp8_quant || is_wfp4a16_quant) { quant_3_size = 0; quant_4_size = 0; @@ -4201,7 +4344,7 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile * sizeof(TmaWarpSpecializedGroupedGemmInput::ElementSF); size_t const fp4_act_scale_flat_size = std::max(fc1_fp4_act_scale_size, fc2_fp4_act_scale_size); - size_t w4a8_alpha_size = is_w4afp8_quant ? num_experts_per_node * sizeof(float) : 0; + size_t w4a8_alpha_size = (is_w4afp8_quant || is_wfp4a16_quant) ? num_experts_per_node * sizeof(float) : 0; size_t alpha_scale_ptr_array_size = num_experts_per_node * sizeof(float**); size_t gemm_workspace_size = mInterface->getGemmWorkspaceSize(num_experts_per_node); @@ -4224,6 +4367,11 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile = mMinLatencyMode ? sizeof(int) * NUM_ROUTING_SAMPLES : 0; // smaller than or equal to num_experts_per_node size_t active_expert_global_ids_size = mMinLatencyMode ? mNumExpertsPerNode * sizeof(int) * NUM_ROUTING_SAMPLES : 0; + bool is_swiglu_bias = mActivationType == ActivationType::SwigluBias && mGemmToProfile == GemmToProfile::GEMM_1; + size_t swiglu_alpha_size = is_swiglu_bias ? num_experts_per_node * sizeof(float) : 0; + size_t swiglu_beta_size = is_swiglu_bias ? num_experts_per_node * sizeof(float) : 0; + size_t swiglu_limit_size = is_swiglu_bias ? num_experts_per_node * sizeof(float) : 0; + size_t map_offset = 0; std::map<std::string, std::pair<size_t, size_t>> out_map; @@ -4263,7 +4411,9 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile ADD(alpha_scale_ptr_array); ADD(fp4_act_scale_flat); ADD(gemm_workspace); - + ADD(swiglu_alpha); + ADD(swiglu_beta); + ADD(swiglu_limit); #undef ADD_NAME #undef ADD @@ -4348,15 +4498,19 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr GET_WS_PTR(float const*, w4a8_alpha); #undef GET_WS_PTR - if ((mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4) && mGroupSize < 0) + if ((mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 + || mWType == nvinfer1::DataType::kUINT8) + && mGroupSize < 0) { TLLM_CHECK(quant_1 && quant_2); mQuantParams = QuantParams::Int(quant_1, quant_2); } - else if (mWType == nvinfer1::DataType::kINT4) + else if (mWType == nvinfer1::DataType::kINT4 || mWType == nvinfer1::DataType::kUINT8) { TLLM_CHECK(quant_1 && quant_2); - if (mDType == nvinfer1::DataType::kFP8) + if (mDType == nvinfer1::DataType::kFP8 + || (mWType == nvinfer1::DataType::kUINT8 + && (mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16))) { TLLM_CHECK(w4a8_alpha); mQuantParams = QuantParams::GroupWise( @@ -4457,17 +4611,17 @@ void GemmProfilerBackend::prepareTmaWsInputs( gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; - bool apply_bias = true; bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); + bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) + && mWType == nvinfer1::DataType::kUINT8); + bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; bool using_fused_finalize - = !mInterface->use_deterministic_hopper_reduce_ && mSM == 90 && !mMinLatencyMode && !use_w4afp8; + = mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise; if (using_fused_finalize) { assert(!mMinLatencyMode); gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - gemm2_tma_ws_input.setFinalizeFusionParams(output, token_topk_unpermuted_scales, - expert_first_token_offset, permuted_row_to_unpermuted_row, apply_bias ? bias : nullptr, - mExpertHiddenSize, num_tokens); + gemm2_tma_ws_input.setFinalizeFusionParams(output, mExpertHiddenSize, num_tokens, mK > 1); } auto fc1_output_size = isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize; @@ -4488,7 +4642,7 @@ void GemmProfilerBackend::prepareTmaWsInputs( fc1_output_size, mExpertHiddenSize, mExpertHiddenSize, mExpertInterSize, mNumExpertsPerNode, input, input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate, - stream); + token_topk_unpermuted_scales, permuted_row_to_unpermuted_row, stream); } sync_check_cuda_error(stream); } @@ -4560,6 +4714,10 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac GET_WS_PTR(TmaWarpSpecializedGroupedGemmInput::ElementSF*, fp4_act_scale_flat); GET_WS_PTR(void*, gemm_workspace); + GET_WS_PTR(float*, swiglu_alpha); + GET_WS_PTR(float*, swiglu_beta); + GET_WS_PTR(float*, swiglu_limit); + #undef GET_WS_PTR_OFFSET #undef GET_WS_PTR @@ -4592,7 +4750,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac mExpertHiddenSize, // mExpertInterSize, // num_experts_per_node, // - mActivationType, // + ActivationParams(mActivationType, swiglu_alpha, swiglu_beta, swiglu_limit), alpha_scale_ptr_array, // !mUseLora, // /*use_deepseek_fp8_block_scale=*/false, // @@ -4672,11 +4830,13 @@ template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half>; template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half, half>; +template class CutlassMoeFCRunner<half, __nv_fp4_e2m1>; #ifdef ENABLE_BF16 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>; #endif #endif diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index bc16af1d26..9f2296dd7d 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -47,8 +47,7 @@ constexpr bool isValidSM120MOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice return cutlass::platform::is_same<T, __nv_fp4_e2m1>::value && cutlass::platform::is_same<T, WeightType>::value - && cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value - && Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + && cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value; #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -63,8 +62,7 @@ constexpr bool isValidBlackwellMOESpecialisation() return (cutlass::platform::is_same<T, WeightType>::value || (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value && cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) - && cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value - && Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + && cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value; #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -79,7 +77,9 @@ constexpr bool isValidHopperMOESpecialisation() #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) return (cutlass::platform::is_same<T, WeightType>::value || (cutlass::platform::is_same<cutlass::uint4b_t, WeightType>::value - && cutlass::platform::is_same<T, __nv_fp8_e4m3>::value)) + && cutlass::platform::is_same<T, __nv_fp8_e4m3>::value) + || (cutlass::platform::is_same<__nv_fp4_e2m1, WeightType>::value + && !cutlass::platform::is_same<T, __nv_fp8_e4m3>::value)) #ifdef ENABLE_FP4 && !cutlass::platform::is_same<T, __nv_fp4_e2m1>::value #endif diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py index bdb8af652d..05e4bd33e9 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py @@ -1,7 +1,7 @@ import argparse import enum import os -from itertools import product +from itertools import chain, product from cutlass_library import * @@ -111,7 +111,10 @@ CudaTypeName = { DataType.e4m3: "__nv_fp8_e4m3", DataType.bf16: "__nv_bfloat16", DataType.f16: "half", - DataType.f32: "float" + DataType.f32: "float", + DataType.e2m1: "__nv_fp4_e2m1", + DataType.ue8m0: "cutlass::float_ue8m0_t", + DataType.u4: "cutlass::uint4b_t" } @@ -209,14 +212,13 @@ template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {s {kernel_sched}, {epi_sched}> ( const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, {out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* -); -""" +);""" elif operation.gemm_kind == GemmKind.Grouped: if operation.act_type != operation.weight_type and ( operation.act_type != DataType.e4m3 or operation.weight_type != e2m1): # Mixed MoE GEMM - weight_tag = DataTypeTag[operation.weight_type] + weight_tag = CudaTypeName[operation.weight_type] instantiation = f""" template void sm90_generic_mixed_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag}, {epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> ( @@ -258,11 +260,9 @@ GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSp # (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*); # """ instantiation = f""" -#if {guard_act} && {guard_weight}\n - INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, - {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);\n -#endif -""" +#if {guard_act} && {guard_weight} + INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false); +#endif""" return instantiation @@ -273,8 +273,7 @@ def instantiate_operation_sm80(operation): instantiation = f""" template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}> - ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy); - """ + ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);""" return instantiation @@ -337,6 +336,10 @@ def write_file(launcher_inl_files, operations, output_file): f.write(content) +def elementwise(x, y, f): + return tuple(f(a, b) for (a, b) in zip(x, y)) + + def is_gemm_op_valid_sm100(op): # TODO These are much more restricted than theory dictates, investigate if more can be enabled in future tile_m, tile_n, _ = op.cta_shape @@ -361,10 +364,11 @@ def is_gemm_op_valid_sm100(op): return False # Shapes for fp8 small N shapes - if (op.act_type == DataType.e4m3 and (tile_n == 16 or tile_n == 8) - and (cga_m == 1 and cga_n == 1)): - # todo: double check why this is disable in CUTLASS backend. @yuhan - return not (tile_m == 128 and tile_n % 16 != 0) + if (op.act_type == DataType.e4m3) and (tile_n == 16 + or tile_n == 8) and (cga_m == 1 + and cga_n == 1): + # todo: double check why tile_n = 8 is disabled in CUTLASS backend. @yuhan + return tile_m != 128 or tile_n % 16 == 0 # Default alignment requirements if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256: @@ -537,11 +541,19 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled): if not is_arch_enabled: return [] arch = 90 - supported_dtypes = [ + + # act_type, weight_type, scalezero_type, bias_type, output_type + supported_dtypes_int4 = [ (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16), (DataType.e4m3, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16), ] + supported_dtypes_fp4 = [ + (DataType.f16, DataType.e2m1, DataType.ue8m0, DataType.f16, + DataType.f16), + (DataType.bf16, DataType.e2m1, DataType.ue8m0, DataType.bf16, + DataType.bf16), + ] quant_ops = [TrtLlm_QuantOp.finegrained_scale_only] @@ -550,15 +562,24 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled): M_TILES = [64, 128] # Currently M tile must be 128 for Grouped GEMM N_TILES = [16, 32, 64, 128] K_TILES = [128, 256, 512] - cta_shapes_mnk = list(product(M_TILES, N_TILES, K_TILES)) + cta_shapes_mnk_int4 = list(product(M_TILES, N_TILES, K_TILES)) + + M_TILES = [64, 128] # Currently M tile must be 128 for Grouped GEMM + N_TILES = [16, 32, 64] + K_TILES = [128, 256] + cta_shapes_mnk_fp4 = list(product(M_TILES, N_TILES, K_TILES)) + cta_shapes_mnk_fp4.append((128, 128, 128)) warp_shape = [0, 0, 0] # ignored except for naming stages = 0 # auto - cga_shapes = product([1, 2], [1, 2], [1]) + cga_shapes = list(product([1, 2], [1, 2], [1])) - partial_args = product(supported_dtypes, quant_ops, epi_tags, - cta_shapes_mnk, cga_shapes) + partial_args_int4 = product(supported_dtypes_int4, quant_ops, epi_tags, + cta_shapes_mnk_int4, cga_shapes) + partial_args_fp4 = product(supported_dtypes_fp4, quant_ops, epi_tags, + cta_shapes_mnk_fp4, cga_shapes) + partial_args = chain(partial_args_int4, partial_args_fp4) operations = list() for dtype_combo, quant_op, epi_tag, cta_shape_mnk, cga_shape in partial_args: @@ -592,8 +613,6 @@ def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype): cta_shape_k = max_k_bits // GetDataTypeBits(dtype) if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8): cta_shape_k = 256 - if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16): - cta_shape_k = 128 return cta_shape_mn + (cta_shape_k, ) @@ -613,7 +632,7 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize + TrtLlm_EpilogueFusion.epilogue_fusion_finalize ] cga_shapes = [[1, 1, 1]] @@ -665,7 +684,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled, arch): epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize + TrtLlm_EpilogueFusion.epilogue_fusion_finalize ] cga_shapes = list(product([1, 2], [1, 2], [1])) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h index 4540a39ebf..3f2705f2ee 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h @@ -129,6 +129,8 @@ struct Multihead_attention_params_base float rotary_embedding_scale = 1.0f; // The pre-computed rotary inv freq when building the engines (as constant weights). float const* rotary_embedding_inv_freq_cache = nullptr; + // The pre-computed cos/sin cache. + float2 const* rotary_embedding_cos_sin_cache = nullptr; float rotary_embedding_short_m_scale = 1.0f; float rotary_embedding_long_m_scale = 1.0f; int rotary_embedding_max_positions = 0; @@ -152,6 +154,9 @@ struct Multihead_attention_params_base bool const* attention_mask = nullptr; int attention_mask_stride = 0; + // The attention sinks [num_heads_q]. + float const* attention_sinks = nullptr; + // If relative position embedding is used T const* relative_attention_bias = nullptr; int relative_attention_bias_stride = 0; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h index 8536b940a7..2403504a90 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h @@ -1363,8 +1363,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske #ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS if (sizeof(Tk) != 4) { - auto const max_timesteps - = min(timestep, min(static_cast<unsigned>(cyclic_kv_cache_len), chunked_attention_size)); + auto const max_timesteps = min(timestep, static_cast<unsigned>(cyclic_kv_cache_len)); logits_smem_ += divUp(max_timesteps + 1, 4u) * 16; } Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_); @@ -1718,6 +1717,13 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske float rotary_embedding_m_scale = tlength <= params.rotary_embedding_original_max_positions ? params.rotary_embedding_short_m_scale : params.rotary_embedding_long_m_scale; + // The rotary cos_sin cache for the current timestep + float2 const* cos_sin_cache = params.rotary_embedding_cos_sin_cache; + if (cos_sin_cache) + { + cos_sin_cache += (static_cast<int64_t>(position_idx) * params.rotary_embedding_dim / 2); + } + mmha::vec_from_smem_transpose(q, q_smem_, transpose_idx, smem_pitch); if (HANDLE_KV) { @@ -1725,7 +1731,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, position_idx, rotary_embedding_inv_freq_cache, - rotary_embedding_m_scale, params.rotary_cogvlm_vision_start, params.rotary_cogvlm_vision_length); + cos_sin_cache, rotary_embedding_m_scale, params.rotary_cogvlm_vision_start, + params.rotary_cogvlm_vision_length); mmha::write_smem_transpose(k, k_smem_, transpose_idx, smem_pitch); } @@ -1733,7 +1740,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske { mmha::apply_rotary_embedding(q, transpose_idx / tidx_factor, params.rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, position_idx, rotary_embedding_inv_freq_cache, - rotary_embedding_m_scale, params.rotary_cogvlm_vision_start, params.rotary_cogvlm_vision_length); + cos_sin_cache, rotary_embedding_m_scale, params.rotary_cogvlm_vision_start, + params.rotary_cogvlm_vision_length); } mmha::write_smem_transpose(q, q_smem_, transpose_idx, smem_pitch); } @@ -2243,6 +2251,13 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Compute the sum. sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum); + // Add the attention sinks. + // It has been moved to the end of the kernel if the multi-block mode is enabled. + if (!MULTI_BLOCK_FLAG && params.attention_sinks != nullptr) + { + sum += expf(params.attention_sinks[hi] - qk_max); + } + // Normalize the logits. #ifdef MMHA_FP8_SCALE_P_INSTEAD_OF_V float logit_scale = (FP8_KV_CACHE ? kv_scale_quant_orig_f : 1.0f); @@ -2693,6 +2708,12 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske if (oo == 0 && (Dh == Dh_MAX || oi < Dh)) { + // Add the attention sinks. + if (params.attention_sinks != nullptr) + { + final_sum += expf(params.attention_sinks[hi] - final_max); + } + auto const inv_sum = __fdividef( write_attention_quant ? *params.attention_out_scale_orig_quant : 1.f, final_sum + 1.e-6f); diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp index 3da27ff38c..ac331ac33f 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp @@ -84,12 +84,6 @@ bool DecoderXQAImplJIT::mayHavePerfGain(XQAParams const& xqaParams) const // Always use at least 1 block regardless of history length multi_block_count = std::max(1, history_length / kMinHistoryTokensPerBlock); } - // Disable XQA for sliding window when cyclic_attention_window_size <= 256. - if (xqaParams.max_past_kv_length + 1 > xqaParams.cyclic_attention_window_size - && xqaParams.cyclic_attention_window_size <= 256) - { - return false; - } int block_count = num_kv_heads * batch_size * multi_block_count; return static_cast<float>(block_count) * kEnableMinBlockFactor >= static_cast<float>(mRunner->mMultiProcessorCount); } @@ -418,6 +412,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& { appendParam(&launchParams.ropeCosSin); } + appendParam(&xqaParams.attention_sinks); appendParam(&launchParams.kvCacheParams); if (xqaParams.beam_width > 1) { diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h index 4c1ab13f05..97ad58335f 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h @@ -34,6 +34,7 @@ struct XQAParams void* output_sf = nullptr; void const* qkv = nullptr; int32_t const* cache_indir = nullptr; + float const* attention_sinks = nullptr; float const* kv_scale_orig_quant = nullptr; float const* kv_scale_quant_orig = nullptr; int32_t const* host_past_key_value_lengths = nullptr; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h index 5f767a2504..09bd551c0b 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h @@ -2609,8 +2609,8 @@ inline __device__ void update_rotary_base_n_scale(float& base, float& scale, Rot } } -inline __device__ float2 rotary_embedding_coefficient(float const* inv_freq_cache, int const zid, - int const rot_embed_dim, float const base, float const scale, float const mscale, float const t_step, +inline __device__ float2 rotary_embedding_coefficient(float const* inv_freq_cache, float2 const* cos_sin_cache, + int const zid, int const rot_embed_dim, float const base, float const scale, float const mscale, float const t_step, int const vision_start = -1, int const vision_length = -1) { float real_step = t_step; @@ -2631,7 +2631,12 @@ inline __device__ float2 rotary_embedding_coefficient(float const* inv_freq_cach } } // Load from global memory cache if it is not nullptr. - if (inv_freq_cache) + if (cos_sin_cache) + { + float2 const cos_sin = cos_sin_cache[zid / 2]; + return {cos_sin.x, cos_sin.y}; + } + else if (inv_freq_cache) { float const inv_freq = float(real_step) * inv_freq_cache[zid / 2]; return {cosf(inv_freq) * mscale, sinf(inv_freq) * mscale}; @@ -2668,48 +2673,49 @@ inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 #endif inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, float base, float scale, int t_step, - float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, int vision_length = -1) + float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { return; } inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { return; } inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (2 * tid >= rot_embed_dim) { return; } - auto const coef = rotary_embedding_coefficient( - inv_freq_cache, 2 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 2 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q = rotary_embedding_transform(q, coef); } inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (2 * tid >= rot_embed_dim) { return; } - auto const coef = rotary_embedding_coefficient( - inv_freq_cache, 2 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 2 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q = rotary_embedding_transform(q, coef); k = rotary_embedding_transform(k, coef); } inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (4 * tid >= rot_embed_dim) { @@ -2717,17 +2723,17 @@ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_ } Float4_& q_ = *reinterpret_cast<Float4_*>(&q); - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q_.x = rotary_embedding_transform(q_.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q_.y = rotary_embedding_transform(q_.y, coef1); } inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (4 * tid >= rot_embed_dim) { @@ -2736,19 +2742,19 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int Float4_& q_ = *reinterpret_cast<Float4_*>(&q); Float4_& k_ = *reinterpret_cast<Float4_*>(&k); - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q_.x = rotary_embedding_transform(q_.x, coef0); k_.x = rotary_embedding_transform(k_.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q_.y = rotary_embedding_transform(q_.y, coef1); k_.y = rotary_embedding_transform(k_.y, coef1); } inline __device__ void apply_rotary_embedding(Float8_& q, Float8_& k, int tid, int rot_embed_dim, float base, - float scale, int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + float scale, int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, + float mscale = 1.0f, int vision_start = -1, int vision_length = -1) { if (8 * tid >= rot_embed_dim) { @@ -2757,93 +2763,95 @@ inline __device__ void apply_rotary_embedding(Float8_& q, Float8_& k, int tid, i Float8_& q_ = *reinterpret_cast<Float8_*>(&q); Float8_& k_ = *reinterpret_cast<Float8_*>(&k); - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q_.x = rotary_embedding_transform(q_.x, coef0); k_.x = rotary_embedding_transform(k_.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q_.y = rotary_embedding_transform(q_.y, coef1); k_.y = rotary_embedding_transform(k_.y, coef1); - auto const coef2 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 4, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef2 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 4, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q_.z = rotary_embedding_transform(q_.z, coef2); k_.z = rotary_embedding_transform(k_.z, coef2); - auto const coef3 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 6, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef3 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 6, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q_.w = rotary_embedding_transform(q_.w, coef3); k_.w = rotary_embedding_transform(k_.w, coef3); } inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (2 * tid >= rot_embed_dim) { return; } - auto const coef = rotary_embedding_coefficient( - inv_freq_cache, 2 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 2 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q = rotary_embedding_transform(q, coef); } inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, float base, - float scale, int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + float scale, int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, + float mscale = 1.0f, int vision_start = -1, int vision_length = -1) { if (2 * tid >= rot_embed_dim) { return; } - auto const coef = rotary_embedding_coefficient( - inv_freq_cache, 2 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 2 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q = rotary_embedding_transform(q, coef); k = rotary_embedding_transform(k, coef); } inline __device__ void apply_rotary_embedding(half2& q, int tid, int rot_embed_dim, float base, float scale, int t_step, - float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, int vision_length = -1) + float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { return apply_rotary_embedding(*reinterpret_cast<uint32_t*>(&q), tid, rot_embed_dim, base, scale, mscale, - inv_freq_cache, t_step, vision_start, vision_length); + inv_freq_cache, cos_sin_cache, t_step, vision_start, vision_length); } inline __device__ void apply_rotary_embedding(half2& q, half2& k, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { return apply_rotary_embedding(*reinterpret_cast<uint32_t*>(&q), *reinterpret_cast<uint32_t*>(&k), tid, - rot_embed_dim, base, scale, mscale, inv_freq_cache, t_step, vision_start, vision_length); + rot_embed_dim, base, scale, mscale, inv_freq_cache, cos_sin_cache, t_step, vision_start, vision_length); } inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, float base, float scale, int t_step, - float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, int vision_length = -1) + float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (4 * tid >= rot_embed_dim) { return; } - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q.x = rotary_embedding_transform(q.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.y = rotary_embedding_transform(q.y, coef1); } inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (4 * tid >= rot_embed_dim) { return; } - float2 coef0 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); - float2 coef1 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + float2 coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); + float2 coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid + 2, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q.x = rotary_embedding_transform(q.x, coef0); k.x = rotary_embedding_transform(k.x, coef0); @@ -2852,108 +2860,109 @@ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int r } inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, float base, float scale, int t_step, - float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, int vision_length = -1) + float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (8 * tid >= rot_embed_dim) { return; } - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q.x = rotary_embedding_transform(q.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.y = rotary_embedding_transform(q.y, coef1); - auto const coef2 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 4, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef2 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 4, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.z = rotary_embedding_transform(q.z, coef2); - auto const coef3 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 6, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef3 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 6, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.w = rotary_embedding_transform(q.w, coef3); } inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (8 * tid >= rot_embed_dim) { return; } - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q.x = rotary_embedding_transform(q.x, coef0); k.x = rotary_embedding_transform(k.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.y = rotary_embedding_transform(q.y, coef1); k.y = rotary_embedding_transform(k.y, coef1); - auto const coef2 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 4, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef2 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 4, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.z = rotary_embedding_transform(q.z, coef2); k.z = rotary_embedding_transform(k.z, coef2); - auto const coef3 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 6, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef3 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 6, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.w = rotary_embedding_transform(q.w, coef3); k.w = rotary_embedding_transform(k.w, coef3); } #ifdef ENABLE_BF16 inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) -{ - if (2 * tid >= rot_embed_dim) - { - return; - } - auto const coef = rotary_embedding_coefficient( - inv_freq_cache, 2 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, - float base, float scale, int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, int vision_start = -1, int vision_length = -1) { if (2 * tid >= rot_embed_dim) { return; } - auto const coef = rotary_embedding_coefficient( - inv_freq_cache, 2 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 2 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, + float base, float scale, int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, + float mscale = 1.0f, int vision_start = -1, int vision_length = -1) +{ + if (2 * tid >= rot_embed_dim) + { + return; + } + auto const coef = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 2 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q = rotary_embedding_transform(q, coef); k = rotary_embedding_transform(k, coef); } inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (4 * tid >= rot_embed_dim) { return; } - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q.x = rotary_embedding_transform(q.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.y = rotary_embedding_transform(q.y, coef1); } inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, float base, - float scale, int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + float scale, int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, + float mscale = 1.0f, int vision_start = -1, int vision_length = -1) { if (4 * tid >= rot_embed_dim) { return; } - float2 coef0 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); - float2 coef1 = rotary_embedding_coefficient( - inv_freq_cache, 4 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + float2 coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); + float2 coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 4 * tid + 2, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q.x = rotary_embedding_transform(q.x, coef0); k.x = rotary_embedding_transform(k.x, coef0); q.y = rotary_embedding_transform(q.y, coef1); @@ -2961,49 +2970,49 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, } inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, float base, float scale, - int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, float mscale = 1.0f, + int vision_start = -1, int vision_length = -1) { if (8 * tid >= rot_embed_dim) { return; } - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q.x = rotary_embedding_transform(q.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.y = rotary_embedding_transform(q.y, coef1); - auto const coef2 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 4, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef2 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 4, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.z = rotary_embedding_transform(q.z, coef2); - auto const coef3 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 6, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef3 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 6, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.w = rotary_embedding_transform(q.w, coef3); } inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, float base, - float scale, int t_step, float const* inv_freq_cache = nullptr, float mscale = 1.0f, int vision_start = -1, - int vision_length = -1) + float scale, int t_step, float const* inv_freq_cache = nullptr, float2 const* cos_sin_cache = nullptr, + float mscale = 1.0f, int vision_start = -1, int vision_length = -1) { if (8 * tid >= rot_embed_dim) { return; } - auto const coef0 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef0 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid, rot_embed_dim, base, scale, + mscale, t_step, vision_start, vision_length); q.x = rotary_embedding_transform(q.x, coef0); k.x = rotary_embedding_transform(k.x, coef0); - auto const coef1 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 2, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef1 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 2, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.y = rotary_embedding_transform(q.y, coef1); k.y = rotary_embedding_transform(k.y, coef1); - auto const coef2 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 4, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef2 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 4, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.z = rotary_embedding_transform(q.z, coef2); k.z = rotary_embedding_transform(k.z, coef2); - auto const coef3 = rotary_embedding_coefficient( - inv_freq_cache, 8 * tid + 6, rot_embed_dim, base, scale, mscale, t_step, vision_start, vision_length); + auto const coef3 = rotary_embedding_coefficient(inv_freq_cache, cos_sin_cache, 8 * tid + 6, rot_embed_dim, base, + scale, mscale, t_step, vision_start, vision_length); q.w = rotary_embedding_transform(q.w, coef3); k.w = rotary_embedding_transform(k.w, coef3); } diff --git a/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp b/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp index 068dfb026a..82336a961a 100644 --- a/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp +++ b/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp @@ -165,6 +165,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams) tllmRunnerParams.vPtr = nullptr; tllmRunnerParams.kvPtr = kvPoolPtr; tllmRunnerParams.qkvPtr = runnerParams.qkvPtr; + tllmRunnerParams.attentionSinksPtr = runnerParams.attentionSinksPtr; tllmRunnerParams.cumSeqLensQPtr = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr); tllmRunnerParams.cumSeqLensKvPtr = reinterpret_cast<int const*>(runnerParams.cuKvSeqLenPtr); tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(runnerParams.scaleBmm2Ptr); diff --git a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/fp4_converter.cuh b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/fp4_converter.cuh index 6a480dbc05..13de943b43 100644 --- a/cpp/tensorrt_llm/kernels/fusedLayernormKernels/fp4_converter.cuh +++ b/cpp/tensorrt_llm/kernels/fusedLayernormKernels/fp4_converter.cuh @@ -120,9 +120,8 @@ struct FP4Converter<TIn, UE8M0_SF, std::enable_if_t<std::is_same_v<TIn, half> || SFValue = static_cast<float>(tmp); } - auto SFOffset - = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, NUM_THREADS_PER_SF, SF_VEC_SIZE>(std::nullopt /* batchIdx */, - rowIdx, colIdx, std::nullopt /* numRows */, numCols, SFout, FP4QuantizationSFLayout::SWIZZLED); + auto SFOffset = cvt_quant_get_sf_out_offset<uint32_t, NUM_THREADS_PER_SF>(std::nullopt /* batchIdx */, rowIdx, + colIdx, std::nullopt /* numRows */, numCols / SF_VEC_SIZE, SFout, QuantizationSFLayout::SWIZZLED); *SFOffset = fp8SFVal; // Get the output scale. // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)) @@ -234,9 +233,8 @@ struct FP4Converter<float, UE8M0_SF> SFValue = static_cast<float>(tmp); } - auto SFOffset - = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, NUM_THREADS_PER_SF, SF_VEC_SIZE>(std::nullopt /* batchIdx */, - rowIdx, colIdx, std::nullopt /* numRows */, numCols, SFout, FP4QuantizationSFLayout::SWIZZLED); + auto SFOffset = cvt_quant_get_sf_out_offset<uint32_t, NUM_THREADS_PER_SF>(std::nullopt /* batchIdx */, rowIdx, + colIdx, std::nullopt /* numRows */, numCols / SF_VEC_SIZE, SFout, QuantizationSFLayout::SWIZZLED); float outputScale = reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)); // Convert the input to float. diff --git a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu index 9ce057ee7b..93a838b09f 100644 --- a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu +++ b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu @@ -65,7 +65,12 @@ __global__ void fusedQKNormRopeKernel( __nv_bfloat16 const* k_weight, // RMSNorm weights for key float const base, // Base for RoPE computation int const* position_ids, // Position IDs for RoPE - int const num_tokens // Number of tokens + int const num_tokens, // Number of tokens + // parameters for yarn + float factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn. + float low, // threshold for high frequency + float high, // threshold for low frequency + float attention_factor // attention_factor applied on cos and sin ) { int const warpsPerBlock = blockDim.x / 32; @@ -170,6 +175,25 @@ __global__ void fusedQKNormRopeKernel( int dim_idx = laneId * numElemsPerThread + i; int half_dim = dim_idx / 2; float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim)); + + if (factor != 1.0f) + { + float inv_freq_extrapolation = freq; + float inv_freq_interpolation = freq / factor; + + // linear_ramp_factor + if (fabsf(low - high) <= 1e-6f) + { + high += 0.001; // Prevent singularity + } + float linear_func = (static_cast<float>(half_dim) - low) / (high - low); + // clamp linear_func to [0.0f, 1.0f] + float ramp_func = fmin(fmax(linear_func, 0.0f), 1.0f); + float inv_freq_extrapolation_factor = 1.0f - ramp_func; + freq = inv_freq_interpolation * (1.0f - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor; + } + float theta = pos_id * freq; __sincosf(theta, &sin_vals[i], &cos_vals[i]); } @@ -191,6 +215,25 @@ __global__ void fusedQKNormRopeKernel( dim_idx = (dim_idx * 2) % head_dim; int half_dim = dim_idx / 2; float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim)); + + if (factor != 1.0f) + { + float inv_freq_extrapolation = freq; + float inv_freq_interpolation = freq / factor; + + // linear_ramp_factor + if (fabsf(low - high) <= 1e-6f) + { + high += 0.001; // Prevent singularity + } + float linear_func = (static_cast<float>(half_dim) - low) / (high - low); + // clamp linear_func to [0.0f, 1.0f] + float ramp_func = fmin(fmax(linear_func, 0.0f), 1.0f); + float inv_freq_extrapolation_factor = 1.0f - ramp_func; + freq = inv_freq_interpolation * (1.0f - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor; + } + float theta = pos_id * freq; __sincosf(theta, &sin_vals[i], &cos_vals[i]); } @@ -200,7 +243,7 @@ __global__ void fusedQKNormRopeKernel( for (int i = 0; i < numElemsPerThread; i++) { - elements[i] = elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]; + elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor; } // Store. @@ -232,8 +275,13 @@ __global__ void fusedQKNormRopeKernel( void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k, int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight, - float const base, bool const interleave, int const* position_ids, cudaStream_t stream) + float const base, bool const interleave, int const* position_ids, float factor, float low, float high, + float attention_factor, cudaStream_t stream) { + if (factor == 1.0f) + { + TLLM_CHECK(attention_factor == 1.0f); + } constexpr int blockSize = 256; int const warpsPerBlock = blockSize / 32; @@ -250,18 +298,18 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_ { case 64: DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - fusedQKNormRopeKernel<64, INTERLEAVE> - <<<gridDim, blockDim, 0, stream>>>(reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, - num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight), - reinterpret_cast<__nv_bfloat16 const*>(k_weight), base, position_ids, num_tokens); + fusedQKNormRopeKernel<64, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>( + reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps, + reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight), + base, position_ids, num_tokens, factor, low, high, attention_factor); }); break; case 128: DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - fusedQKNormRopeKernel<128, INTERLEAVE> - <<<gridDim, blockDim, 0, stream>>>(reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, - num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight), - reinterpret_cast<__nv_bfloat16 const*>(k_weight), base, position_ids, num_tokens); + fusedQKNormRopeKernel<128, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>( + reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps, + reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight), + base, position_ids, num_tokens, factor, low, high, attention_factor); }); break; default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim); diff --git a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h index 09146a8d03..041efd8a14 100644 --- a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h +++ b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h @@ -38,7 +38,11 @@ void launchFusedQKNormRope( float const base, // Base for RoPE computation bool const interleave, // Whether RoPE is applied in interleave mode (non-Neox style) int const* position_ids, // Position IDs for RoPE [num_tokens] - cudaStream_t stream); // CUDA stream + float factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn. + float low, // threshold for high frequency + float high, // threshold for low frequency + float attention_factor, // attention_factor applied on cos and sin + cudaStream_t stream); // CUDA stream } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index 1a800b30dc..08cd9b6f66 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6d12357919fe6c63749a81e124afd60453153489a3f50cb44b41671d9b55f947 -size 64338696 +oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf +size 67051604 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt index 62c9a58c08..8b500f5c97 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -ad34c0f31247c880d60e2c8198093e8373cf0e1d3e8badee0424bfa607d6cd8e libtensorrt_llm_internal_cutlass_kernels_static.a -commit bac309ac608d35d7d0144e594bf3e5fa8cfca796 +568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a +commit 9c0a42825905952beaf9b35d5a35d58de1a123fa diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h index a52f4b7aaf..b7d1340a2e 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h @@ -39,11 +39,6 @@ namespace tensorrt_llm { -template <class T> -constexpr auto transpose_stride(T const& t) -{ - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t)); -} // Note update moe.py to match enum class ActivationType @@ -53,6 +48,7 @@ enum class ActivationType Silu, Swiglu, Geglu, + SwigluBias, Identity, InvalidType }; @@ -86,8 +82,6 @@ struct GroupedGemmInput struct TmaWarpSpecializedGroupedGemmInput { - template <class T> - using TransposeStride = decltype(transpose_stride<T>(T{})); template <class Tag> using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; @@ -100,6 +94,7 @@ struct TmaWarpSpecializedGroupedGemmInput using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand + using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand constexpr static int NVFP4BlockScaleVectorSize = 16; constexpr static int MXFPXBlockScaleVectorSize = 32; @@ -121,6 +116,7 @@ struct TmaWarpSpecializedGroupedGemmInput using StrideB = std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>; + using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>; #ifdef ENABLE_FP8 template <class T> @@ -147,37 +143,26 @@ struct TmaWarpSpecializedGroupedGemmInput StrideC* stride_c = nullptr; void const** ptr_c = nullptr; - struct DefaultEpilogue - { - using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>; - - StrideD* stride_d = nullptr; - void** ptr_d = nullptr; - }; + // D is used in all cases except fused finalize + StrideD* stride_d = nullptr; + void** ptr_d = nullptr; struct FusedFinalizeEpilogue { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>; - using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>; + using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>; void* ptr_final_output = nullptr; StrideFinalOutput stride_final_output{}; - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; + void const** ptr_bias = nullptr; + float const** ptr_router_scales = nullptr; - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; + int const** ptr_source_token_index = nullptr; + int num_rows_in_final_output = 0; - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; - - size_t num_rows_in_final_output = 0; + bool use_reduction = true; }; - DefaultEpilogue default_epilogue; FusedFinalizeEpilogue fused_finalize_epilogue; enum class EpilogueFusion @@ -210,7 +195,8 @@ struct TmaWarpSpecializedGroupedGemmInput struct INT4GroupwiseParams { - constexpr static int group_size = 128; // Unused, hard-coded to 128 + constexpr static int int4_group_size = 128; + constexpr static int wfp4a16_group_size = 32; bool enabled = false; using SFA = __nv_bfloat16; using SFB = __nv_bfloat16; // Unused @@ -233,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput uint8_t* gemm_workspace = nullptr; size_t gemm_workspace_size = 0; - static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); + static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type); @@ -245,16 +231,15 @@ struct TmaWarpSpecializedGroupedGemmInput return stride_a != nullptr && ptr_a != nullptr; } - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); + void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction); std::string toString() const; }; constexpr bool isGatedActivation(ActivationType activation_type) { - return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu + || activation_type == ActivationType::SwigluBias; } template <typename T, /*The type used for activations/scales/compute*/ @@ -305,9 +290,9 @@ public: [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; [[nodiscard]] bool supportsTmaWarpSpecialized() const; - [[nodiscard]] bool isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; - [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; + [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, + ActivationType activation_type, int gemm_n, int gemm_k) const; + [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; size_t getMaxWorkspaceSize(int num_experts) const; diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h index f7d1705709..1bda2247ce 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h @@ -62,6 +62,36 @@ private: int num_bits_; }; +struct ActivationParams +{ + ActivationType activation_type; + float const* swiglu_alpha = nullptr; + float const* swiglu_beta = nullptr; + float const* swiglu_limit = nullptr; + + explicit ActivationParams(ActivationType activation_type) + : activation_type(activation_type) + { + TLLM_CHECK_WITH_INFO(activation_type != ActivationType::SwigluBias, + "SwigluBias is not supported in ActivationParams without swiglu_alpha and swiglu_beta"); + } + + ActivationParams( + ActivationType activation_type, float const* swiglu_alpha, float const* swiglu_beta, float const* swiglu_limit) + : activation_type(activation_type) + , swiglu_alpha(swiglu_alpha) + , swiglu_beta(swiglu_beta) + , swiglu_limit(swiglu_limit) + { + } + + // TODO Port everything properly and get rid of these implicit conversions + operator ActivationType() const + { + return activation_type; + } +}; + /** * \brief Describes what parallelism mode the MoE is using * @@ -175,9 +205,9 @@ struct QuantParams { bool use_per_expert_act_scale = false; float const* act_global_scale - = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale + = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale - nullptr for fc1 TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale - = nullptr; // (experts, n, k / 32) + = nullptr; // (experts, n, k / 32) float const* global_scale = nullptr; // (num_experts_per_node, ) }; @@ -393,10 +423,10 @@ public: virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, + ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, bool use_lora, + int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) = 0; @@ -409,7 +439,7 @@ public: float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert) @@ -420,8 +450,8 @@ public: void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales, float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row, - int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row, + float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row, + int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, @@ -438,7 +468,8 @@ public: void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) = 0; virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> @@ -455,13 +486,13 @@ public: virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; - bool use_deterministic_hopper_reduce_ = false; + bool use_fused_finalize_ = true; }; // Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . // Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. // Avoid making several duplicates of this class. -template <typename T, /*The type used for activations*/ +template <typename T, /* The type used for activations */ typename WeightType, /* The type for the MoE weights */ typename OutputType = T, /* The type for the MoE final output */ typename InputType = T, /* The type for the MoE input */ @@ -540,10 +571,10 @@ public: void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, + ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, bool use_lora, + int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; @@ -562,7 +593,7 @@ public: TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert); @@ -573,8 +604,8 @@ public: ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales, float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row, - int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row, + float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row, + int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, @@ -589,7 +620,7 @@ public: float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override @@ -609,8 +640,8 @@ public: void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales, float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row, - int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row, + float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row, + int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, @@ -623,11 +654,11 @@ public: static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template, static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases), static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params, - token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row, - expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, - expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, - use_lora, fc2_lora, stream, parallelism_config, config, min_latency_mode, num_active_experts_per, - active_expert_global_ids, start_expert); + token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row, + permuted_row_to_unpermuted_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, expanded_num_rows, + hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora, + stream, parallelism_config, config, min_latency_mode, num_active_experts_per, active_expert_global_ids, + start_expert); } virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override @@ -643,7 +674,8 @@ public: void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override { return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, @@ -652,7 +684,8 @@ public: alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params, reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2), reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output), - reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), stream); + reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row, + stream); } std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> @@ -677,7 +710,7 @@ public: private: std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> setupTmaWarpSpecializedInputs( - int64_t num_rows, int64_t expanded_num_rows, ActivationType fc1_activation_type, int64_t hidden_size, + int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, int64_t inter_size, int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -694,7 +727,8 @@ private: float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); + UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row, + cudaStream_t stream); static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, @@ -724,8 +758,8 @@ private: bool mayHaveFinalizeFused() const { - return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90 - && !use_deterministic_hopper_reduce_ && !use_w4afp8; + return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_ + && !use_w4afp8; } // TODO: This should eventually take the quant params to give more flexibility @@ -756,12 +790,12 @@ private: WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, - ActivationType fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); + ActivationParams fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases, - float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row, + float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row, int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config, diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index 2178f48db9..f1a6b9dc88 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:53b6f54a21bd547c0da17e3723b7822d4ee16b66b66a545948c0cbee5760bf65 -size 63835444 +oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d +size 66872936 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt index 721c4d5e52..4af58b0800 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -21c59ede16aa448b6135327bd0f95e72a6e614f219935b8f67fe635b3cb4b38b libtensorrt_llm_internal_cutlass_kernels_static.a -commit bac309ac608d35d7d0144e594bf3e5fa8cfca796 +813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a +commit 9c0a42825905952beaf9b35d5a35d58de1a123fa diff --git a/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu b/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu index b6d78cd82d..c6ae3a6d49 100644 --- a/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu +++ b/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu @@ -188,9 +188,6 @@ __global__ void mergeAttnWithSoftmaxKernel(T* merged_attn, float2* merged_softma // load softmax stat int const global_softmax_stats_offset = (global_q_offset + local_token_idx) * num_heads + head_idx; float2 curr_stats = curr_softmax_stats[global_softmax_stats_offset]; - // hack, current softmax stats max is not multiplied by bmm1_scale - // TODO: delete this line when trtllm gen kernel return the right max value. - curr_stats.x *= 0.072168784; // 1 / sqrt(128 + 64), head_size is 128 for output, but for bmm1 is 192 float2 pre_stats = pre_softmax_stats[global_softmax_stats_offset]; // load attn diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.cu b/cpp/tensorrt_llm/kernels/mlaKernels.cu index 2849eba71d..cdb7abbb91 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.cu +++ b/cpp/tensorrt_llm/kernels/mlaKernels.cu @@ -207,8 +207,9 @@ inline __device__ void dequantCopy( template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer> __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const* fuse_buf, KVCacheBuffer kv_cache, float2 const* cos_sin_cache, size_t head_num, int head_size, int c_k, int* cu_q_seqlens, - int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type, - float const* quant_scale_kv) + int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type, float* bmm1_scale, + float* bmm2_scale, float const* quant_scale_o, float const* quant_scale_kv, float const* dequant_scale_q, + float const* dequant_scale_kv, float host_bmm1_scale) { // Constants. @@ -231,6 +232,32 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const* size_t const batch_idx = blockIdx.y; size_t const head_idx = blockIdx.z; + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) + { + + // Calculate bmm scale for FP8 MLA + if (cache_type == KvCacheDataType::FP8) + { + float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f; + float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f; + float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f; + if (bmm1_scale) + { + // The scale prepared for log2 optimization. + constexpr float kLog2e = 1.4426950408889634074f; + // The scale after fmha bmm1. + float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale; + bmm1_scale[0] = bmm1_scale_val; + bmm1_scale[1] = bmm1_scale_val * kLog2e; + } + if (bmm2_scale) + { + // The scale after fmha bmm2. + bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val; + } + } + } + if (head_idx < head_num) { size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD); @@ -919,10 +946,54 @@ void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, c { dim3 grid(int(tensorrt_llm::common::divUp(params.max_input_seq_len, 32)), params.batch_size, params.head_num + 8); auto head_size = params.meta.qk_nope_head_dim; - applyMLARopeAndAssignQKVKernelOptContext<T, 256, 512, 64, KVCacheBuffer> - <<<grid, 256, 0, stream>>>(params.attention_input_buf, params.latent_cache, kv_cache_buffer, - params.cos_sin_cache, params.head_num, head_size, params.meta.kv_lora_rank, params.cu_q_seqlens, - params.cache_seq_lens, params.max_input_seq_len, params.cache_type, params.quant_scale_kv); + applyMLARopeAndAssignQKVKernelOptContext<T, 256, 512, 64, KVCacheBuffer><<<grid, 256, 0, stream>>>( + params.attention_input_buf, params.latent_cache, kv_cache_buffer, params.cos_sin_cache, params.head_num, + head_size, params.meta.kv_lora_rank, params.cu_q_seqlens, params.cache_seq_lens, params.max_input_seq_len, + params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_kv, + params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale); + if (params.attention_input_buf != nullptr && params.quant_attention_input_buf != nullptr + && params.cache_type == KvCacheDataType::FP8) + { + TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing attention_input_buf to FP8"); + + int const dim_q_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim); + int const dim_k_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim); + int const dim_v_per_head = (params.meta.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = params.head_num * dim_q_per_head; + int const total_k_dim_all_heads + = params.head_num * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = params.head_num * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + + int const num_total_qkv_elements + = params.acc_q_len * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + size_t headDim = params.meta.kv_lora_rank + params.meta.qk_rope_head_dim; + float const* device_qkv_scale_ptr = params.quant_scale_qkv; + + if (num_total_qkv_elements > 0) + { + int const threads_per_block = 256; + int const num_blocks = (num_total_qkv_elements + threads_per_block - 1) / threads_per_block; + + TLLM_LOG_DEBUG( + "Launching QuantizeCopyInputToFp8Kernel with num_blocks: %d, threads_per_block: %d, elements: %d", + num_blocks, threads_per_block, num_total_qkv_elements); + + tensorrt_llm::kernels::QuantizeCopyInputToFp8Kernel<T><<<num_blocks, threads_per_block, 0, stream>>>( + static_cast<T const*>(params.attention_input_buf), // Source + static_cast<__nv_fp8_e4m3*>(params.quant_attention_input_buf), // Destination + num_total_qkv_elements, device_qkv_scale_ptr); + sync_check_cuda_error(stream); + + cudaStreamSynchronize(stream); + } + else + { + TLLM_LOG_WARNING("MLA RoPE Context: num_total_qkv_elements is 0, skipping quantization."); + } + } } template <typename T, typename KVCacheBuffer> @@ -1037,6 +1108,17 @@ INSTANTIATE_SET_KVCACHE_MLA(float); INSTANTIATE_SET_KVCACHE_MLA(half); INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16); +template <typename T_IN> +__global__ void QuantizeCopyInputToFp8Kernel( + T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr) +{ + uint element_idx = threadIdx.x + blockDim.x * blockIdx.x; + if (element_idx < num_total_elements) + { + float scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f; + output_fp8_buffer[element_idx] = __nv_fp8_e4m3(static_cast<float>(input_buffer[element_idx]) * scale_factor); + } +} } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.h b/cpp/tensorrt_llm/kernels/mlaKernels.h index 3d5aa4f148..e3472c13d8 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.h +++ b/cpp/tensorrt_llm/kernels/mlaKernels.h @@ -87,6 +87,8 @@ struct MlaParams void* context_paged_kv_ptr = nullptr; void* context_kv_cache_block_offsets_ptr = nullptr; int32_t context_paged_kv_max_blocks_per_seq = 0; + // for FP8 context qkv quantization + float const* quant_scale_qkv = nullptr; }; template <typename T, typename KVCacheBuffer> @@ -111,5 +113,9 @@ void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* late float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream); +template <typename T_IN> +__global__ void QuantizeCopyInputToFp8Kernel( + T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr); + } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/quantization.cu b/cpp/tensorrt_llm/kernels/quantization.cu index e78a6c9b30..817b0a57ee 100644 --- a/cpp/tensorrt_llm/kernels/quantization.cu +++ b/cpp/tensorrt_llm/kernels/quantization.cu @@ -132,8 +132,8 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3); // FP4 Quantization template <typename T, int SF_VEC_SIZE> -void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, - bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream) +void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, + bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream) { #ifdef ENABLE_FP8 if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) @@ -143,26 +143,33 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, i dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + // The number of blocks for m. The m dimension will be padded to 128 for swizzled layout. + int numBlocksForM = layout == QuantizationSFLayout::SWIZZLED ? PadUpFn(m, 128) : m; + dim3 grid(std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. - auto* kernel_instance = useUE8M0 ? &cvt_fp8_to_fp4<SF_VEC_SIZE, true> : &cvt_fp8_to_fp4<SF_VEC_SIZE, false>; - kernel_instance<<<grid, block, 0, stream>>>( - m, n, input, SFScale, reinterpret_cast<uint64_t*>(output), reinterpret_cast<uint32_t*>(SFOuput), layout); + auto* kernel_instance = useUE8M0 + ? &quantize_with_block_size<BlockScaleQuantizationType::FP8_TO_FP4, T, SF_VEC_SIZE, true> + : &quantize_with_block_size<BlockScaleQuantizationType::FP8_TO_FP4, T, SF_VEC_SIZE, false>; + kernel_instance<<<grid, block, 0, stream>>>(b, m, n, n, input, SFScale, reinterpret_cast<uint32_t*>(output), + reinterpret_cast<uint32_t*>(SFOuput), layout); } else #endif { // Grid, Block size. // Each thread converts 8 values. - dim3 block(std::min(int(n / CVT_FP4_ELTS_PER_THREAD), 512)); + dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + // The number of blocks for m. The m dimension will be padded to 128 for swizzled layout. + int numBlocksForM = layout == QuantizationSFLayout::SWIZZLED ? PadUpFn(m, 128) : m; + dim3 grid(std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. - auto* kernel_instance - = useUE8M0 ? &cvt_fp16_to_fp4<T, SF_VEC_SIZE, true> : &cvt_fp16_to_fp4<T, SF_VEC_SIZE, false>; + auto* kernel_instance = useUE8M0 + ? &quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE, true> + : &quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE, false>; cudaLaunchConfig_t config; config.gridDim = grid; config.blockDim = block; @@ -173,63 +180,51 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, i attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance, m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), + cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput), layout); } } -template <typename T, int SF_VEC_SIZE> -void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* SFScale, int64_t* output, - int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream) +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MXFP8 Quantization + +template <typename T> +void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int64_t* output, int32_t* SFOuput, + QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream) { -#ifdef ENABLE_FP8 - if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) - { - // Grid, Block size. - // Each thread converts 16 values. - dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512)); - // Get number of blocks per SM (assume we can fully utilize the SM). - int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(m, multiProcessorCount * numBlocksPerSM)); + // Fixed SF_VEC_SIZE as 32 + static constexpr int SF_VEC_SIZE = 32; - // Launch the cvt kernel. - auto* kernel_instance - = useUE8M0 ? &cvt_fp8_to_fp4_3d<SF_VEC_SIZE, true> : &cvt_fp8_to_fp4_3d<SF_VEC_SIZE, false>; - kernel_instance<<<grid, block, 0, stream>>>(b, m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), - reinterpret_cast<uint32_t*>(SFOuput), FP4QuantizationSFLayout::SWIZZLED); - } - else -#endif - { - // Grid, Block size. - // Each thread converts 8 values. - dim3 block(std::min(int(n / CVT_FP4_ELTS_PER_THREAD), 512)); - // Get number of blocks per SM (assume we can fully utilize the SM). - int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(padded_n / CVT_ELTS_PER_THREAD), 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = std::max(1u, 2048u / block.x); + // The number of blocks for m. The m dimension will be padded to 128 for swizzled layout. + int numBlocksForM = layout == QuantizationSFLayout::SWIZZLED ? PadUpFn(m, 128) : m; + dim3 grid(std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM)); - // Launch the cvt kernel. - auto* kernel_instance - = useUE8M0 ? &cvt_fp16_to_fp4_3d<T, SF_VEC_SIZE, true> : &cvt_fp16_to_fp4_3d<T, SF_VEC_SIZE, false>; - cudaLaunchConfig_t config; - config.gridDim = grid; - config.blockDim = block; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance, b, m, n, input, SFScale, reinterpret_cast<uint32_t*>(output), - reinterpret_cast<uint32_t*>(SFOuput), FP4QuantizationSFLayout::SWIZZLED); - } + // Launch the cvt kernel. + cudaLaunchConfig_t config; + config.gridDim = grid; + config.blockDim = block; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, + quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b, m, n, padded_n, + input, nullptr, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput), layout); } -__global__ void nvfp4_block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded, int numCols, +//////////////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded, int numCols, int numColsPadded, uint8_t const* SFIn, uint8_t* SFOutput) { - constexpr int SF_VEC_SIZE = 16; for (int rowIdx = blockIdx.x; rowIdx < numRowsPadded; rowIdx += gridDim.x) { for (int batchIdx = 0; batchIdx < numBatches; batchIdx++) @@ -250,18 +245,16 @@ __global__ void nvfp4_block_scale_interleave_kernel(int numBatches, int numRows, // int const numSfTilesK = (numCols + 4 - 1) / 4; // int const tileOffset = ((mi / 128) * numSfTilesK + ki / 4) * 512; // int const dstIdx = tileOffset + (mi % 32) * 16 + ((mi % 128) / 32) * 4 + ki % 4; - auto dstIdx = get_sf_out_offset_128x4<SF_VEC_SIZE>( - batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * SF_VEC_SIZE); + auto dstIdx = get_sf_out_offset_128x4(batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols); SFOutput[dstIdx] = sf; } } } } -__global__ void nvfp4_block_scale_interleave_reverse_kernel( +__global__ void block_scale_interleave_reverse_kernel( int numBatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput) { - constexpr int SF_VEC_SIZE = 16; for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { for (int batchIdx = 0; batchIdx < numBatches; batchIdx++) @@ -272,8 +265,7 @@ __global__ void nvfp4_block_scale_interleave_reverse_kernel( std::optional<int> numRowsOpt = numRows; // Get the swizzled input index using the same swizzling pattern - auto srcIdx = get_sf_out_offset_128x4<SF_VEC_SIZE>( - batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * SF_VEC_SIZE); + auto srcIdx = get_sf_out_offset_128x4(batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols); auto sf = SFIn[srcIdx]; // Output goes to linear layout @@ -285,8 +277,8 @@ __global__ void nvfp4_block_scale_interleave_reverse_kernel( } // This is intended for weight loading, so m and n are large, b <= 256 -void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn, - uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream) +void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn, uint8_t* SFOutput, + int multiProcessorCount, cudaStream_t stream) { // Each thread reads 1 int8 value dim3 block(std::min(n_padded, 1024)); @@ -294,11 +286,11 @@ void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_pa int const numBlocksPerSM = std::max(1u, 4096u / block.x); dim3 grid(std::min(m_padded, multiProcessorCount * numBlocksPerSM)); - nvfp4_block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn, SFOutput); + block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn, SFOutput); } // This is intended for weight loading, so m and n are large, b <= 256 -void invokeNVFP4BlockScaleInterleaveReverse( +void invokeBlockScaleInterleaveReverse( int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream) { // Each thread reads 1 int8 value @@ -307,46 +299,36 @@ void invokeNVFP4BlockScaleInterleaveReverse( int const numBlocksPerSM = std::max(1u, 4096u / block.x); dim3 grid(std::min(m, multiProcessorCount * numBlocksPerSM)); - nvfp4_block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput); + block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput); } // Instantiate the function. -template void invokeFP4Quantization<half, 16>(int m, int n, half const* input, float const* SFScale, int64_t* output, - int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream); -template void invokeFP4Quantization<half, 32>(int m, int n, half const* input, float const* SFScale, int64_t* output, - int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream); -template void invokeBatchedFP4Quantization<half, 16>(int b, int m, int n, half const* input, float const* SFScale, - int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream); -template void invokeBatchedFP4Quantization<half, 32>(int b, int m, int n, half const* input, float const* SFScale, - int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream); +template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input, float const* SFScale, + int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount, + cudaStream_t stream); +template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* input, float const* SFScale, + int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount, + cudaStream_t stream); +template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input, int64_t* output, + int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input, float const* SFScale, - int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, - cudaStream_t stream); -template void invokeFP4Quantization<__nv_bfloat16, 32>(int m, int n, __nv_bfloat16 const* input, float const* SFScale, - int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, - cudaStream_t stream); -template void invokeBatchedFP4Quantization<__nv_bfloat16, 16>(int b, int m, int n, __nv_bfloat16 const* input, - float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, - cudaStream_t stream); -template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>(int b, int m, int n, __nv_bfloat16 const* input, - float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, - cudaStream_t stream); +template void invokeFP4Quantization<__nv_bfloat16, 16>(int b, int m, int n, __nv_bfloat16 const* input, + float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, + int multiProcessorCount, cudaStream_t stream); +template void invokeFP4Quantization<__nv_bfloat16, 32>(int b, int m, int n, __nv_bfloat16 const* input, + float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, + int multiProcessorCount, cudaStream_t stream); +template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n, __nv_bfloat16 const* input, + int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream); #endif #ifdef ENABLE_FP8 -template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, - int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, - cudaStream_t stream); -template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, - int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, - cudaStream_t stream); -template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 16>(int b, int m, int n, __nv_fp8_e4m3 const* input, - float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, - cudaStream_t stream); -template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 32>(int b, int m, int n, __nv_fp8_e4m3 const* input, - float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, - cudaStream_t stream); +template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(int b, int m, int n, __nv_fp8_e4m3 const* input, + float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, + int multiProcessorCount, cudaStream_t stream); +template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(int b, int m, int n, __nv_fp8_e4m3 const* input, + float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, + int multiProcessorCount, cudaStream_t stream); #endif } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/quantization.cuh b/cpp/tensorrt_llm/kernels/quantization.cuh index 95768fecfe..7aacc0f31d 100644 --- a/cpp/tensorrt_llm/kernels/quantization.cuh +++ b/cpp/tensorrt_llm/kernels/quantization.cuh @@ -272,10 +272,9 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const nu } //////////////////////////////////////////////////////////////////////////////////////////////////// -// FP4 Quantization +// FP4/MXFP8 Quantization -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; +constexpr int CVT_ELTS_PER_THREAD = 8; constexpr int CVT_FP4_THREADS_PER_WARP = 32; constexpr int CVT_FP8_TO_FP4_ELTS_PER_THREAD = 16; @@ -373,31 +372,6 @@ inline __device__ uint64_t fp32_vec_to_e2m1(float2 (&array)[8]) #endif } -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) -{ - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -// Define a 16 bytes packed data type. -template <class Type> -struct PackedVec -{ - typename TypeConverter<Type>::Type elts[4]; - static_assert(sizeof(elts) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vector size should match the number of elements per thread."); -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> -{ - __nv_fp8x2_e4m3 elts[8]; - static_assert(sizeof(elts) == sizeof(__nv_fp8_e4m3) * CVT_FP8_TO_FP4_ELTS_PER_THREAD, - "Vector size should match the number of elements per thread."); -}; - // Convert 4 float2 values into 8 e4m3 values (represented as one uint64_t). inline __device__ uint64_t fp32_vec_to_e4m3(float2 (&array)[4]) { @@ -416,77 +390,37 @@ inline __device__ uint64_t fp32_vec_to_e4m3(float2 (&array)[4]) return u.val; } -// Quantizes the provided PackedVec into the uint64_t output -template <class Type, int SF_VEC_SIZE> -__device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec<Type>& vec, uint8_t* SFout) +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = cuda_abs(vec.elts[0]); - -// Local maximum value. -#pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) - { - localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); - } - - constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; - // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). - localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - if constexpr (CVT_NUM_THREADS_PER_SF == 4) - { - localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); - } - // Get the final absolute maximum values. - float vecMax = float(cuda_max(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of mxfp8). - float SFValue = vecMax * reciprocal_approximate_ftz(448.0f); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - __nv_fp8_e8m0 tmpSFVal; - tmpSFVal.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); - float SFValueNarrow = static_cast<float>(tmpSFVal); - fp8SFVal = tmpSFVal.__x; - // Get the output scale (reciprocal of the SFValue). - float outputScale = SFValue != 0.f ? reciprocal_approximate_ftz(SFValueNarrow) : 0.0f; - - if (SFout) - { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - -#pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) - { - if constexpr (std::is_same_v<Type, half>) - { - fp2Vals[i] = __half22float2(vec.elts[i]); - } - else - { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e4m3 values. - uint64_t e4m3Vec = fp32_vec_to_e4m3(fp2Vals); - - // Write the e4m3 values to global memory. - return e4m3Vec; -#else - return 0; -#endif + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; } +__device__ __forceinline__ float exp2f_rcp(uint8_t exp) +{ + constexpr uint32_t FP32_EXPONENT_BIAS = 127; + return (exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(exp)); +} + +// Define a 16 bytes packed data type. +template <class Type> +struct PackedVec +{ + typename TypeConverter<Type>::Type elts[4]; + static_assert(sizeof(elts) == sizeof(Type) * CVT_ELTS_PER_THREAD, + "Vector size should match the number of elements per thread."); +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> +{ + __nv_fp8x2_e4m3 elts[8]; + static_assert(sizeof(elts) == sizeof(__nv_fp8_e4m3) * CVT_FP8_TO_FP4_ELTS_PER_THREAD, + "Vector size should match the number of elements per thread."); +}; + // Quantizes the provided PackedVec into the uint32_t output template <class Type, int SF_VEC_SIZE, bool UE8M0_SF> __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) @@ -497,12 +431,12 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, // Local maximum value. #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) { localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); } - constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD; // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); if constexpr (CVT_NUM_THREADS_PER_SF == 4) @@ -512,32 +446,33 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, // Get the final absolute maximum values. float vecMax = float(cuda_max(localMax.x, localMax.y)); - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - float SFValueNarrow; // 8 bits representation of the SF. uint8_t fp8SFVal; + float outputScale; // Write the SF to global memory (STG.8). if constexpr (UE8M0_SF) { __nv_fp8_e8m0 tmp; - tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); - SFValueNarrow = static_cast<float>(tmp); + // Scale the max value to the range of E2m1. + vecMax *= reciprocal_approximate_ftz(6.0f); + tmp.__x = __nv_cvt_float_to_e8m0(vecMax, __NV_SATFINITE, cudaRoundPosInf); fp8SFVal = tmp.__x; + outputScale = vecMax != 0 ? exp2f_rcp(fp8SFVal) : 0.0f; } else { + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + auto SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); // Here SFValue is always positive, so E4M3 is the same as UE4M3. __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); fp8SFVal = tmp.__x; - SFValueNarrow = static_cast<float>(tmp); + SFValue = static_cast<float>(tmp); + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) * reciprocal(SFScaleVal)) + outputScale = vecMax != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)) - float outputScale - = SFValue != 0 ? reciprocal_approximate_ftz(SFValueNarrow * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; if (SFout) { @@ -546,10 +481,10 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, } // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + float2 fp2Vals[CVT_ELTS_PER_THREAD / 2]; #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { if constexpr (std::is_same_v<Type, half>) { @@ -614,6 +549,7 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal, // maximum value of e2m1 = 6.0. // TODO: use half as compute data type. float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + float SFValueNarrow; // 8 bits representation of the SF. uint8_t fp8SFVal; // Write the SF to global memory (STG.8). @@ -621,7 +557,7 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal, { __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); - SFValue = static_cast<float>(tmp); + SFValueNarrow = static_cast<float>(tmp); fp8SFVal = tmp.__x; } else @@ -629,11 +565,11 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal, // Here SFValue is always positive, so E4M3 is the same as UE4M3. __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); fp8SFVal = tmp.__x; - SFValue = static_cast<float>(tmp); + SFValueNarrow = static_cast<float>(tmp); } // Get the output scale. // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)) - float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValue) : 0.0f; + float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValueNarrow) : 0.0f; if (SFout) { @@ -662,9 +598,79 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal, #endif } -template <int SF_VEC_SIZE> -inline __device__ __host__ int64_t get_sf_out_offset_128x4( - std::optional<int> batchIdx, int mIdx, int kIdx, std::optional<int> numRows, int numCols) +// Quantizes the provided PackedVec into the uint64_t output +template <class Type, int SF_VEC_SIZE> +__device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec<Type>& vec, uint8_t* SFout) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = cuda_abs(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) + { + localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); + } + + constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD; + // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + if constexpr (CVT_NUM_THREADS_PER_SF == 4) + { + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); + } + // Get the final absolute maximum values. + float vecMax = float(cuda_max(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of mxfp8). + float SFValue = vecMax * reciprocal_approximate_ftz(448.0f); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + __nv_fp8_e8m0 tmpSFVal; + tmpSFVal.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast<float>(tmpSFVal); + fp8SFVal = tmpSFVal.__x; + // Get the output scale (reciprocal of the SFValue). + float outputScale = vecMax != 0.f ? reciprocal_approximate_ftz(SFValue) : 0.0f; + + if (SFout) + { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) + { + if constexpr (std::is_same_v<Type, half>) + { + fp2Vals[i] = __half22float2(vec.elts[i]); + } + else + { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e4m3 values. + uint64_t e4m3Vec = fp32_vec_to_e4m3(fp2Vals); + + // Write the e4m3 values to global memory. + return e4m3Vec; +#else + return 0; +#endif +} + +inline __host__ __device__ int64_t get_sf_out_offset_128x4( + std::optional<int> batchIdx, int mIdx, int kIdx, std::optional<int> numRows, int numColVecs) { // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] @@ -686,9 +692,9 @@ inline __device__ __host__ int64_t get_sf_out_offset_128x4( int32_t kTileIdx = (kIdx / 4); int64_t kTileStride = 32 * outerMStride; // 512 - // SF vector size 16. We round the "numCols" up to a multiple of 64. - int factor = SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; + // SF vector size 16 or 32. We round the "numCols" up to a multiple of 64 or 128. + // It is the same as rounding the "numColVecs" up to a multiple of 4. + int32_t numKTiles = (numColVecs + 4 - 1) / 4; int32_t mTileIdx = mIdx / (32 * 4); int64_t mTileStride = numKTiles * kTileStride; @@ -703,35 +709,35 @@ inline __device__ __host__ int64_t get_sf_out_offset_128x4( return SFOffset; } -template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF, int SF_VEC_SIZE> -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx, int colIdx, - std::optional<int> numRows, int numCols, SFType* SFout, FP4QuantizationSFLayout layout) +template <class SFType, int CVT_NUM_THREADS_PER_SF> +__device__ uint8_t* cvt_quant_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx, int colVecIdx, + std::optional<int> numRows, int numColVecs, SFType* SFout, QuantizationSFLayout layout) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert( - CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2 || CVT_FP4_NUM_THREADS_PER_SF == 4); + // Each thread holds one vector. + static_assert(CVT_NUM_THREADS_PER_SF == 1 || CVT_NUM_THREADS_PER_SF == 2 || CVT_NUM_THREADS_PER_SF == 4); // One pair of threads write one SF to global memory. // TODO: stage through smem for packed STG.32 // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) + if (threadIdx.x % CVT_NUM_THREADS_PER_SF == 0) { - if (layout == FP4QuantizationSFLayout::SWIZZLED) + if (layout == QuantizationSFLayout::SWIZZLED) { // SF vector index (16 elements share one SF in the K dimension). // numRows and numCols are unpadded. - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t kIdx = colVecIdx / CVT_NUM_THREADS_PER_SF; int32_t mIdx = rowIdx; - auto SFOffset = get_sf_out_offset_128x4<SF_VEC_SIZE>(batchIdx, mIdx, kIdx, numRows, numCols); + auto SFOffset = get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numColVecs); return reinterpret_cast<uint8_t*>(SFout) + SFOffset; } - else if (layout == FP4QuantizationSFLayout::LINEAR) + else if (layout == QuantizationSFLayout::LINEAR) { // Linear row-major layout, no padding required. - int32_t KTileIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t KTileIdx = colVecIdx / CVT_NUM_THREADS_PER_SF; - int32_t numKTiles = numCols / SF_VEC_SIZE; + int32_t numKTiles = numColVecs; int64_t mTileStride = numKTiles; int64_t BTileStride = numRows.value_or(0) * mTileStride; @@ -748,50 +754,109 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI return nullptr; } -// Use UE4M3 by default. -template <class Type, int SF_VEC_SIZE, bool UE8M0_SF> +template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF> __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - __launch_bounds__(512, 4) cvt_fp16_to_fp4_3d( + __launch_bounds__(512, 4) quantize_with_block_size( #else -cvt_fp16_to_fp4_3d( +quantize_with_block_size( #endif - int32_t numbatches, int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, - uint32_t* SFout, FP4QuantizationSFLayout layout) + int32_t numbatches, int32_t numRows, int32_t numCols, int32_t numPaddedCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, QuantizationSFLayout layout) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // The elements per thread. + static constexpr int ELTS_PER_THREAD = quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 + ? CVT_FP8_TO_FP4_ELTS_PER_THREAD + : CVT_ELTS_PER_THREAD; + using PackedVec = PackedVec<Type>; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; // 2 or 4 - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD; // 2 or 4 + static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched."); // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)). float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + // Is it swizzled layout? + bool isSfSwizzledLayout = layout == QuantizationSFLayout::SWIZZLED; + + // The number of padded rows considering 128x4 SF layout. + int numPaddedRowsForSf = isSfSwizzledLayout ? PadUpFn(numRows, 128) : numRows; + int numColsForSf = isSfSwizzledLayout ? PadUpFn(numPaddedCols, 4 * SF_VEC_SIZE) : numPaddedCols; + + // The number of threads in the column dimension。 + // Note that numCols/numPaddedCols/numColsForSf are guaranteed to be multiples of ELTS_PER_THREAD. + int numColThreads = numCols / ELTS_PER_THREAD; + int numPaddedColThreads = numPaddedCols / ELTS_PER_THREAD; + int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD; + asm volatile("griddepcontrol.wait;"); // Input tensor batch/row/col loops. - for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) + for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) { for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { - for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) + for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { - int64_t inOffset = batchIdx * numRows * (numCols / CVT_FP4_ELTS_PER_THREAD) - + rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; - // Get the output tensor offset. - // Same as inOffset because 8 elements are packed into one uint32_t. - int64_t outOffset = inOffset; - auto& out_pos = out[outOffset]; - std::optional<int> optionalBatchIdx = batchIdx; std::optional<int> optionalNumRows = numRows; - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF, SF_VEC_SIZE>( - optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numCols, SFout, layout); + // The SF output pointer. + auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>( + optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout); - out_pos = cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out); + // The input tensor offset. + int64_t inOffset = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx; + int64_t outOffset = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx; + + // Set the values to 0 of those are padded columns. + if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) + { + // Dispatch the quantization kernel. + if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) + { + reinterpret_cast<uint32_t*>(out)[outOffset] = 0u; + } + else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 + || quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) + { + reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull; + } + } + + // Set the SF padding to 0. + if (rowIdx >= numRows || colIdx >= numColThreads) + { + // Set the SF padding to 0. + if (sf_out != nullptr) + { + sf_out[0] = 0x00; + } + } + else + { + // Load the input vector. + PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; + + // Dispatch the quantization kernel. + if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) + { + reinterpret_cast<uint32_t*>(out)[outOffset] + = cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out); + } + else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) + { + reinterpret_cast<uint64_t*>(out)[outOffset] + = cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out); + } + else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) + { + reinterpret_cast<uint64_t*>(out)[outOffset] + = cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out); + } + } } } } @@ -799,141 +864,7 @@ cvt_fp16_to_fp4_3d( #endif } -// Use UE4M3 by default. -template <int SF_VEC_SIZE, bool UE8M0_SF> -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - __launch_bounds__(512, 4) cvt_fp8_to_fp4_3d( -#else -cvt_fp8_to_fp4_3d( -#endif - int32_t numbatches, int32_t numRows, int32_t numCols, __nv_fp8_e4m3 const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, FP4QuantizationSFLayout layout) -{ -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - using PackedVec = PackedVec<__nv_fp8_e4m3>; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP8_TO_FP4_ELTS_PER_THREAD; - static_assert( - sizeof(PackedVec) == sizeof(__nv_fp8_e4m3) * CVT_FP8_TO_FP4_ELTS_PER_THREAD, "Vec size is not matched."); - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)). - float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; - - // Input tensor batch/row/col loops. - for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) - { - for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) - { - for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP8_TO_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) - { - int64_t inOffset = batchIdx * numRows * (numCols / CVT_FP4_ELTS_PER_THREAD) - + rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; - // Get the output tensor offset. - // Same as inOffset because 16 elements are packed into one uint64_t. - int64_t outOffset = inOffset; - auto& out_pos = out[outOffset]; - - std::optional<int> optionalBatchIdx = batchIdx; - std::optional<int> optionalNumRows = numRows; - - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF, SF_VEC_SIZE>( - optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numCols, SFout, layout); - - out_pos = cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out); - } - } - } -#endif -} - -// Use UE4M3 by default. -template <class Type, int SF_VEC_SIZE, bool UE8M0_SF> -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - __launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout, - FP4QuantizationSFLayout layout) -{ -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - using PackedVec = PackedVec<Type>; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)). - float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; - - asm volatile("griddepcontrol.wait;"); - // Input tensor row/col loops. - for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) - { - for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) - { - int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; - // Get the output tensor offset. - // Same as inOffset because 8 elements are packed into one uint32_t. - int64_t outOffset = inOffset; - auto& out_pos = out[outOffset]; - - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF, SF_VEC_SIZE>( - std::nullopt /* batchIdx */, rowIdx, colIdx, std::nullopt /* numRows */, numCols, SFout, layout); - - out_pos = cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out); - } - } - asm volatile("griddepcontrol.launch_dependents;"); -#endif -} - -// Use UE4M3 by default. -template <int SF_VEC_SIZE, bool UE8M0_SF> -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - __launch_bounds__(512, 4) cvt_fp8_to_fp4( -#else -cvt_fp8_to_fp4( -#endif - int32_t numRows, int32_t numCols, __nv_fp8_e4m3 const* in, float const* SFScale, uint64_t* out, uint32_t* SFout, - FP4QuantizationSFLayout layout) -{ -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - using PackedVec = PackedVec<__nv_fp8_e4m3>; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP8_TO_FP4_ELTS_PER_THREAD; - static_assert( - sizeof(PackedVec) == sizeof(__nv_fp8_e4m3) * CVT_FP8_TO_FP4_ELTS_PER_THREAD, "Vec size is not matched."); - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)). - float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; - - // Input tensor row/col loops. - for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) - { - for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP8_TO_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) - { - int64_t inOffset = rowIdx * (numCols / CVT_FP8_TO_FP4_ELTS_PER_THREAD) + colIdx; - PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; - // Get the output tensor offset. - // Same as inOffset because 16 elements are packed into one uint64_t. - int64_t outOffset = inOffset; - auto& out_pos = out[outOffset]; - - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF, SF_VEC_SIZE>( - std::nullopt /* batchIdx */, rowIdx, colIdx, std::nullopt /* numRows */, numCols, SFout, layout); - - out_pos = cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out); - } - } -#endif -} - -__global__ void nvfp4_block_scale_interleave_kernel( +__global__ void block_scale_interleave_kernel( int numbatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/quantization.h b/cpp/tensorrt_llm/kernels/quantization.h index 5112123d21..160a54428a 100644 --- a/cpp/tensorrt_llm/kernels/quantization.h +++ b/cpp/tensorrt_llm/kernels/quantization.h @@ -22,7 +22,7 @@ namespace tensorrt_llm { -enum class FP4QuantizationSFLayout +enum class QuantizationSFLayout { // Block scale factors are stored in swizzled layout for cutlass FP4 kernel. Scale factor // blocks are organized in 512-byte blocks in global memory, with each block having 128x4 FP8 values. @@ -39,19 +39,27 @@ enum class FP4QuantizationSFLayout LINEAR }; +// This denotes the input and output data types of the block scale quantization. +enum class BlockScaleQuantizationType +{ + FP16_TO_FP4 = 0, + FP8_TO_FP4 = 1, + FP16_TO_MXFP8 = 2, +}; + #define PadUpFn(X, Y) ((X + Y - 1) / (Y) * (Y)) // totalCloumn should be in SFMatrix, not activation Matrix, so no sfVecSize needed. -inline int computeFP4SwizzledLayoutSFSize(int totalRow, int totalColumn) +inline int64_t computeSwizzledLayoutSFSize(int totalRow, int totalColumn) { int paddedRow = PadUpFn(totalRow, 128); int paddedColumn = PadUpFn(totalColumn, 4); - return paddedRow * paddedColumn; + return static_cast<int64_t>(paddedRow) * paddedColumn; } -inline int computeFP4LinearLayoutSFSize(int totalRow, int totalColumn) +inline int64_t computeLinearLayoutSFSize(int totalRow, int totalColumn) { - return totalRow * totalColumn; + return static_cast<int64_t>(totalRow) * totalColumn; } namespace kernels @@ -67,17 +75,17 @@ void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows cudaStream_t stream = 0); template <typename T, int SF_VEC_SIZE = 16> -void invokeFP4Quantization(int m, int n, T const* input, float const* globalScale, int64_t* output, int32_t* SFOuput, - bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream = 0); +void invokeFP4Quantization(int b, int m, int n, T const* input, float const* globalScale, int64_t* output, + int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream = 0); -template <typename T, int SF_VEC_SIZE = 16> -void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* globalScale, int64_t* output, - int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream = 0); +template <typename T> +void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int64_t* output, int32_t* SFOuput, + QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream = 0); -void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn, - uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0); +void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn, uint8_t* SFOutput, + int multiProcessorCount, cudaStream_t stream = 0); -void invokeNVFP4BlockScaleInterleaveReverse( +void invokeBlockScaleInterleaveReverse( int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0); } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu b/cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu index 82bc08cf84..d4bafb3db6 100644 --- a/cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu +++ b/cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu @@ -31,13 +31,11 @@ namespace kernels template <typename Tout> __global__ void reduce4ring_attention( // this is the accumulated results for all finished ring attention blocks - Tout* __restrict__ accu_output, // b x s_block x h x d - float* __restrict__ accu_softmax_sum, // b x s_block x h - float* __restrict__ accu_max, // b x s_block x h + Tout* __restrict__ accu_output, // b x s_block x h x d + float* __restrict__ accu_softmax_stats, // b x s_block x h x 2 (max/sum) // this is the new ring attention block results - Tout* __restrict__ output, // b x s_block x h x d - float* __restrict__ softmax_sum, // b x s_block x h - float* __restrict__ max, // b x s_block x h + Tout* __restrict__ output, // b x s_block x h x d + float* __restrict__ softmax_stats, // b x s_block x h x 2 (max/sum) // necessary constant parameters int const b, int const s_block, int const h, int const d, int const block_seq_len, int* cu_seqlens) { @@ -48,7 +46,12 @@ __global__ void reduce4ring_attention( int block_s_end = (block_seq_idx + 1) * block_seq_len; block_s_end = s_block < block_s_end ? s_block : block_s_end; int64_t output_start_offset = batchid * s_block * d + block_s_start * d; - int64_t lm_start_offset = batchid * s_block + block_s_start; + int64_t lm_start_offset = (batchid * s_block + block_s_start) * 2; + + float* accu_softmax_sum = accu_softmax_stats + 1; + float* accu_max = accu_softmax_stats; + float* softmax_sum = softmax_stats + 1; + float* max = softmax_stats; __shared__ cuda::barrier<cuda::thread_scope::thread_scope_block> barrier; if (block.thread_rank() == 0) @@ -68,7 +71,7 @@ __global__ void reduce4ring_attention( float scaled_my_ss1_ = 1.0, scaled_my_ss2_ = 1.0; if (s_ < s_len) { - uint64_t lm_start_offset_ = lm_start_offset + s_; + uint64_t lm_start_offset_ = lm_start_offset + s_ * 2; float my_accu_ss = accu_softmax_sum[lm_start_offset_] == 0.0 ? 1.0 : accu_softmax_sum[lm_start_offset_]; float my_ss = softmax_sum[lm_start_offset_] == 0.0 ? 1.0 : softmax_sum[lm_start_offset_]; @@ -123,8 +126,8 @@ void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* out int dim_s = (s + block_seq_len - 1) / block_seq_len; dim3 block_num(b, dim_s, 1); - reduce4ring_attention<Tout><<<block_num, threads_per_block, 0, stream>>>(accu_output, accu_softmax_sum, - accu_softmax_max, output, softmax_sum, softmax_max, b, s, h, d, block_seq_len, cu_seqlens); + reduce4ring_attention<Tout><<<block_num, threads_per_block, 0, stream>>>( + accu_output, accu_softmax_stats, output, softmax_stats, b, s, h, d, block_seq_len, cu_seqlens); } #define INSTANTIATE_RECOVER_RA(Tout) \ diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h index 5ad6a96371..ace034dc43 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h +++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h @@ -16,6 +16,7 @@ */ #pragma once +#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/runtime/common.h" @@ -161,9 +162,8 @@ struct TopKSamplingKernelParams } TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0); - - TLLM_CHECK(0 < maxTopP && maxTopP <= 1.f); - TLLM_CHECK(0 <= maxTopK && maxTopK <= TOP_K_MAX); + TLLM_CHECK_WITH_INFO(0 < maxTopP && maxTopP <= 1.f, "maxTopP (%f) is out of range", maxTopP); + TLLM_CHECK_WITH_INFO(0 <= maxTopK && maxTopK <= TOP_K_MAX, "maxTopK (%d) is out of range", maxTopK); TLLM_CHECK((skipOutputIdCurrentStep && outputIdCurrentStep && returnAllSelectedTokens) || (skipOutputIdCurrentStep == nullptr && outputIdCurrentStep == nullptr)); } diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp index 8122288524..e0f2d5cce2 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp +++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp @@ -14,47 +14,58 @@ * limitations under the License. */ #include "ub_allocator.h" +#include "tensorrt_llm/common/opUtils.h" +#include <set> +#include <stdexcept> namespace tensorrt_llm::runtime::ub { UserBufferAllocator& UserBufferAllocator::Instance() { - static UserBufferAllocator _; - return _; -} - -void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& world_config) -{ - if (!is_initialized()) + if (use_nccl_symmetric) { - ub_comm_ = nullptr; - world_config_ = world_config; - create_communicator_grouped2(&ub_comm_, world_config_); - TLLM_CHECK(ub_comm_ != nullptr); - is_initialized_ = true; + static NCCLUserBufferAllocator _; + return _; + } + else + { + static UserBufferAllocator _; + return _; } } -bool UserBufferAllocator::is_initialized() +void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig) { - return is_initialized_; + if (!isInitialized()) + { + mUbComm = nullptr; + mWorldConfig = worldConfig; + create_communicator_grouped2(&mUbComm, worldConfig); + TLLM_CHECK(mUbComm != nullptr); + mIsInitialized = true; + } } -UBBuffer UserBufferAllocator::register_ub_buffer(size_t bytes) +bool UserBufferAllocator::isInitialized() { - TLLM_CHECK(is_initialized()); + return mIsInitialized; +} + +UBBuffer UserBufferAllocator::registerUBBuffer(size_t bytes) +{ + TLLM_CHECK(isInitialized()); void* addr = nullptr; int handle = -1; - handle = register_user_buffer_collective((void**) &addr, bytes, ub_comm_); + handle = register_user_buffer_collective((void**) &addr, bytes, mUbComm); return {addr, handle, bytes}; } UBBuffer UserBufferAllocator::allocate(size_t bytes) { - TLLM_CHECK(is_initialized()); - auto ub_buffer = register_ub_buffer(bytes); + TLLM_CHECK(isInitialized()); + auto ub_buffer = registerUBBuffer(bytes); TLLM_CHECK(!ub_buffer.invalid()); - buffers_.push_back(ub_buffer); + mBuffers.push_back(ub_buffer); return ub_buffer; } @@ -62,13 +73,177 @@ void UserBufferAllocator::deallocate(void* addr) {} UBBuffer UserBufferAllocator::get(int idx) { - TLLM_CHECK(is_initialized() && idx < buffers_.size() && !buffers_[idx].invalid()); - return buffers_[idx]; + TLLM_CHECK(isInitialized() && idx < mBuffers.size() && !mBuffers[idx].invalid()); + return mBuffers[idx]; } communicator* UserBufferAllocator::comm() { - TLLM_CHECK(is_initialized()); - return ub_comm_; + TLLM_CHECK(isInitialized()); + return mUbComm; } + +void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig) +{ + if (!isInitialized()) + { + TLLM_LOG_INFO("Initializing NCCLUserBufferAllocator"); + std::set<int> group; + for (int i = 0; i < worldConfig.getSize(); i++) + { + group.insert(i); + } + mComm = getComm(group); + mIsInitialized = true; + } +} + +UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes) +{ + TLLM_CHECK(isInitialized()); + UBBuffer ub_buffer; + + auto& ncclHelper = getNCCLHelper(); + if (!ncclHelper.isLoaded()) + { + TLLM_THROW("NCCL library could not be loaded for dynamic symbol access"); + } + + auto ncclMemAllocFunc = ncclHelper.getNCCLMemAlloc(); + auto ncclCommWindowRegisterFunc = ncclHelper.getNCCLCommWindowRegister(); + + NCCLCHECK(ncclMemAllocFunc(&ub_buffer.addr, bytes)); + NCCLCHECK(ncclCommWindowRegisterFunc((*mComm), ub_buffer.addr, bytes, &ub_buffer.window, NCCL_WIN_COLL_SYMMETRIC)); + ub_buffer.handle = 5; + ub_buffer.size = bytes; + return ub_buffer; +} + +// Static member definitions +std::unique_ptr<NCCLHelper> NCCLUserBufferAllocator::mNCCLHelper = nullptr; + +NCCLHelper& NCCLUserBufferAllocator::getNCCLHelper() +{ + if (!mNCCLHelper) + { + mNCCLHelper = std::make_unique<NCCLHelper>(); + } + return *mNCCLHelper; +} + +// NCCLHelper implementation +NCCLHelper::NCCLHelper() + : mLibraryHandle(nullptr) + , mNCCLCommWindowRegister(nullptr) + , mNCCLMemAlloc(nullptr) + , mIsLoaded(false) +{ + loadNCCLLibrary(); +} + +NCCLHelper::~NCCLHelper() +{ + if (mLibraryHandle) + { +#ifdef _WIN32 + FreeLibrary(mLibraryHandle); +#else + dlclose(mLibraryHandle); +#endif + mLibraryHandle = nullptr; + } +} + +void NCCLHelper::loadNCCLLibrary() +{ + try + { +#ifdef _WIN32 + char const* libraryNames[] = {"nccl.dll"}; +#else + char const* libraryNames[] = {"libnccl.so"}; +#endif + + for (int i = 0; libraryNames[i] != nullptr; ++i) + { + mLibraryHandle = loadLibraryHandle(libraryNames[i]); + if (mLibraryHandle) + { + TLLM_LOG_INFO("Successfully loaded NCCL library: %s", libraryNames[i]); + break; + } + } + + if (!mLibraryHandle) + { + TLLM_LOG_WARNING("Failed to load NCCL library"); + return; + } + + // Load the required symbols + mNCCLCommWindowRegister + = reinterpret_cast<ncclCommWindowRegisterFunc>(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister")); + + mNCCLMemAlloc = reinterpret_cast<ncclMemAllocFunc>(getSymbolAddress(mLibraryHandle, "ncclMemAlloc")); + + if (mNCCLCommWindowRegister == nullptr) + { + TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported."); + } + + if (mNCCLMemAlloc) + { + mIsLoaded = true; + } + else + { + TLLM_LOG_WARNING("Failed to load required NCCL symbols"); + } + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what()); + } +} + +void* NCCLHelper::loadLibraryHandle(char const* libName) +{ +#ifdef _WIN32 + return LoadLibraryA(libName); +#else + return dlopen(libName, RTLD_LAZY | RTLD_GLOBAL); +#endif +} + +void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName) +{ + if (!handle) + { + return nullptr; + } + +#ifdef _WIN32 + return GetProcAddress(static_cast<HMODULE>(handle), symbolName); +#else + return dlsym(handle, symbolName); +#endif +} + +NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister() +{ + return mNCCLCommWindowRegister; +} + +NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc() +{ + return mNCCLMemAlloc; +} + +bool NCCLHelper::isLoaded() const +{ + return mIsLoaded; +} + +bool UserBufferAllocator::use_nccl_symmetric = false; + }; // namespace tensorrt_llm::runtime::ub diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h index 9e5c2ee4cb..37a48e5035 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h +++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h @@ -14,9 +14,16 @@ * limitations under the License. */ #pragma once +#include "nccl.h" #include "tensorrt_llm/runtime/worldConfig.h" +#include <memory> #if ENABLE_MULTI_DEVICE #include "userbuffers.h" +#ifdef _WIN32 +#include <windows.h> +#else +#include <dlfcn.h> +#endif #endif namespace tensorrt_llm::runtime::ub @@ -28,11 +35,13 @@ struct UBBuffer void* addr; int handle; size_t size; + ncclWindow_t window; - UBBuffer(void* a = nullptr, int h = -1, size_t s = 0) + UBBuffer(void* a = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr) : addr(a) , handle(h) , size(s) + , window(w) { } @@ -49,21 +58,74 @@ public: UserBufferAllocator() = default; - void initialize(tensorrt_llm::runtime::WorldConfig const& world_config); - bool is_initialized(); + virtual void initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig); + bool isInitialized(); UBBuffer allocate(size_t bytes); void deallocate(void* addr); UBBuffer get(int idx); communicator* comm(); + virtual UBBuffer registerUBBuffer(size_t bytes); + + static bool use_nccl_symmetric; private: - UBBuffer register_ub_buffer(size_t bytes); + communicator* mUbComm; - communicator* ub_comm_; - std::vector<UBBuffer> buffers_; - bool is_initialized_; - tensorrt_llm::runtime::WorldConfig world_config_; +protected: + std::vector<UBBuffer> mBuffers; + bool mIsInitialized; + tensorrt_llm::runtime::WorldConfig mWorldConfig; }; + +class NCCLHelper +{ +public: + NCCLHelper(); + ~NCCLHelper(); + + // Dynamic loading function type definition + using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int); + using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t); + + // Get function pointer for ncclCommWindowRegister + ncclCommWindowRegisterFunc getNCCLCommWindowRegister(); + + // Get function pointer for ncclMemAlloc + ncclMemAllocFunc getNCCLMemAlloc(); + + // Check if NCCL library is successfully loaded + bool isLoaded() const; + +private: + void loadNCCLLibrary(); + void* loadLibraryHandle(char const* libName); + void* getSymbolAddress(void* handle, char const* symbolName); + +#ifdef _WIN32 + HMODULE mLibraryHandle; +#else + void* mLibraryHandle; +#endif + + ncclCommWindowRegisterFunc mNCCLCommWindowRegister; + ncclMemAllocFunc mNCCLMemAlloc; + bool mIsLoaded; +}; + +class NCCLUserBufferAllocator : public UserBufferAllocator +{ +public: + void initialize(tensorrt_llm::runtime::WorldConfig const& world_config) override; + UBBuffer registerUBBuffer(size_t bytes) override; + + // Get shared NCCLHelper instance + static NCCLHelper& getNCCLHelper(); + +private: + std::shared_ptr<ncclComm_t> mComm; + static std::unique_ptr<NCCLHelper> mNCCLHelper; +}; + #else using communicator = void; #endif diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp b/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp index d7a3e69981..6d5f62b260 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp +++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp @@ -36,7 +36,7 @@ void ub_initialize(int tp_size) bool ub_is_initialized() { - return UserBufferAllocator::Instance().is_initialized(); + return UserBufferAllocator::Instance().isInitialized(); } UBBuffer ub_allocate(size_t bytes) diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu b/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu index 5ff106f8be..cc7b491bda 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu +++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu @@ -475,7 +475,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4_mc(PackedVec<Type>& vec, float SFScaleV // Local maximum value. #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) { localMax = __hmax2(localMax, __habs2(vec.elts[i])); } @@ -530,10 +530,10 @@ __device__ uint32_t cvt_warp_fp16_to_fp4_mc(PackedVec<Type>& vec, float SFScaleV } // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + float2 fp2Vals[CVT_ELTS_PER_THREAD / 2]; #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { if constexpr (std::is_same_v<Type, half>) { @@ -650,9 +650,9 @@ __global__ void __launch_bounds__(MAX_THREADS) uint8_t* sf_out = nullptr; if (threadIdx.x % 8 == 0) { - sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt /* batchIdx */, - token_idx, threadIdx.x + g * loop_step0, std::nullopt /* numRows */, hidden_dim, - scale_out + scale_out_offset, FP4QuantizationSFLayout::SWIZZLED); + sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt /* batchIdx */, token_idx, + threadIdx.x + g * loop_step0, std::nullopt /* numRows */, hidden_dim / SF_VEC_SIZE, + scale_out + scale_out_offset, QuantizationSFLayout::SWIZZLED); } uint32_t val = cvt_warp_fp16_to_fp4_mc<DType, SF_VEC_SIZE>(valout, sf, sf_out); MULTIMEM_ST(val, mc_ptr_out + (out_lineoffset + line + g * loop_step0)); @@ -763,9 +763,9 @@ __global__ void __launch_bounds__(MAX_THREADS) (threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j)); i++; } - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt /* batchIdx */, - token_idx, threadIdx.x + g * loop_step0, std::nullopt /* numRows */, hidden_dim, - scale_out + scale_out_offset, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt /* batchIdx */, token_idx, + threadIdx.x + g * loop_step0, std::nullopt /* numRows */, hidden_dim / SF_VEC_SIZE, + scale_out + scale_out_offset, QuantizationSFLayout::SWIZZLED); mc_ptr_out[out_lineoffset + line + g * loop_step0] = cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(valout, sf, sf_out); } diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp index c636eec3d9..a1fcd3c01f 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp +++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp @@ -29,11 +29,14 @@ UserBuffersManager& UserBuffersManager::get_instance() return allocator; } -void UserBuffersManager::initialize( - int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size) +void UserBuffersManager::initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, + int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric) { std::lock_guard<std::mutex> lock(mutex_); tensorrt_llm::runtime::WorldConfig world_config(tp_size, pp_size, cp_size, rank, gpus_per_node); +#if ENABLE_MULTI_DEVICE + UserBufferAllocator::Instance().use_nccl_symmetric = use_nccl_symmetric; +#endif tensorrt_llm::runtime::ub::ub_initialize(world_config); TLLM_CHECK(tensorrt_llm::runtime::ub::ub_is_initialized()); buffer_size_ = buffer_size; @@ -95,10 +98,11 @@ tensorrt_llm::runtime::ub::communicator* UserBuffersManager::comm() return tensorrt_llm::runtime::ub::ub_comm(); } -void initialize_userbuffers_manager( - int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size) +void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, + int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric) { - UserBuffersManager::get_instance().initialize(tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size); + UserBuffersManager::get_instance().initialize( + tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size, use_nccl_symmetric); } } // namespace tensorrt_llm::runtime::ub diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h index 7ec39db602..1b34f8e8a1 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h +++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h @@ -46,8 +46,9 @@ public: //! @param gpus_per_node The number of GPUs per node. //! @param buffer_size The size of the buffer to allocate. All buffers allocated by this manager will have this //! size. - void initialize( - int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size); + //! @param use_nccl_symmetric Whether to use NCCL symmetric communication. + void initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, + int64_t buffer_size, bool use_nccl_symmetric); //! @brief Create a UB tensor from the given shape, strides and data type. The function will choose available UB //! buffer or create a new one if no available buffer is found. @@ -75,7 +76,7 @@ private: int64_t buffer_size_; }; -void initialize_userbuffers_manager( - int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size); +void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, + int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric); } // namespace tensorrt_llm::runtime::ub diff --git a/cpp/tensorrt_llm/kernels/xqaDispatcher.cpp b/cpp/tensorrt_llm/kernels/xqaDispatcher.cpp index ea73b40035..0c6d35ec3f 100644 --- a/cpp/tensorrt_llm/kernels/xqaDispatcher.cpp +++ b/cpp/tensorrt_llm/kernels/xqaDispatcher.cpp @@ -385,6 +385,7 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff tllmRunnerParams.multiCtasKvCounterPtr = launchParams.semaphores; tllmRunnerParams.multiCtasKvScratchPtr = launchParams.scratch; + tllmRunnerParams.attentionSinksPtr = params.attention_sinks; tllmRunnerParams.cumSeqLensQPtr = cu_seqlens; tllmRunnerParams.cumSeqLensKvPtr = reinterpret_cast<int const*>(launchParams.cu_kv_seq_lens); tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(launchParams.bmm2_scale_ptr); diff --git a/cpp/tensorrt_llm/nanobind/CMakeLists.txt b/cpp/tensorrt_llm/nanobind/CMakeLists.txt index aa5b3cf45d..1ccb50a02b 100755 --- a/cpp/tensorrt_llm/nanobind/CMakeLists.txt +++ b/cpp/tensorrt_llm/nanobind/CMakeLists.txt @@ -17,6 +17,7 @@ set(SRCS testing/modelSpecBinding.cpp runtime/moeBindings.cpp userbuffers/bindings.cpp + thop/bindings.cpp ../runtime/ipcNvlsMemory.cu bindings.cpp) @@ -42,7 +43,9 @@ target_link_libraries( ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python - ${CUDA_NVML_LIB}) + CUDA::cuda_driver + ${CUDA_NVML_LIB} + th_common) target_compile_definitions( ${TRTLLM_NB_MODULE} PUBLIC TRTLLM_NB_MODULE=${TRTLLM_NB_MODULE} PYBIND11_DETAILED_ERROR_MESSAGES=1) @@ -52,6 +55,6 @@ if(NOT WIN32) ${TRTLLM_NB_MODULE} PROPERTIES LINK_FLAGS - "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' -Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib/stubs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" + "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" ) endif() diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp index 357a502bcc..1944784eee 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp @@ -103,23 +103,21 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_ "__call__", [](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, - tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType, - DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, - tensorrt_llm::runtime::CudaStream const& runtimeStream, + nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, + runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const& runtimeStream, tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, SizeType32 beamWidth) { OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt; - auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig, - worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState, - runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers); + auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] + = self(modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers, + decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers); return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs), std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)}; }, nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"), - nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"), - nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"), - nb::arg("max_sequence_length"), nb::arg("beam_width")) + nb::arg("logits_type"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), + nb::arg("runtime_stream"), nb::arg("decoder_stream"), nb::arg("max_sequence_length"), nb::arg("beam_width")) .def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; }); } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index 56fdbf14e9..c170ca8101 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -248,7 +248,8 @@ void initBindings(nb::module_& m) } }) .def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest) - .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics); + .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics) + .def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel); nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr()) .def( @@ -375,7 +376,8 @@ void initBindings(nb::module_& m) .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, nb::arg("manager")) .def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason")) .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) - .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter")); + .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter")) + .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors); nb::class_<tb::SequenceSlotManager>(m, "SequenceSlotManager") .def(nb::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), nb::arg("max_num_slots"), diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 74049eaf96..412698215a 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -325,7 +325,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0); nb::class_<tbk::KVCacheEventManager>(m, "KVCacheEventManager") - .def(nb::init<size_t>(), nb::arg("max_kv_event_entries")); + .def(nb::init<size_t, std::optional<SizeType32>, std::optional<SizeType32>, SizeType32>(), + nb::arg("max_kv_event_entries"), nb::arg("attention_dp_rank") = std::nullopt, + nb::arg("attention_dp_size") = std::nullopt, nb::arg("attention_dp_events_gather_period_ms") = 5); nb::class_<tbk::BaseKVCacheManager, PyKvCacheManager>(m, "BaseKVCacheManager") .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, nb::arg("config"), diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index 89cfa72211..c951f967c2 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -39,6 +39,7 @@ #include "tensorrt_llm/nanobind/executor/bindings.h" #include "tensorrt_llm/nanobind/runtime/bindings.h" #include "tensorrt_llm/nanobind/testing/modelSpecBinding.h" +#include "tensorrt_llm/nanobind/thop/bindings.h" #include "tensorrt_llm/nanobind/userbuffers/bindings.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/cudaStream.h" @@ -124,9 +125,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m) auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings"); auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings"); auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings"); + auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings"); tensorrt_llm::nanobind::executor::initBindings(mExecutor); tensorrt_llm::nanobind::runtime::initBindingsEarly(mInternalRuntime); + tensorrt_llm::nanobind::thop::initBindings(mInternalThop); auto buildInfo = m.def_submodule("BuildInfo"); buildInfo.attr("ENABLE_MULTI_DEVICE") = nb::int_(ENABLE_MULTI_DEVICE); @@ -245,12 +248,17 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .def_prop_ro("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq) .def_prop_ro("has_nvfp4", &tc::QuantMode::hasNvfp4) .def_prop_ro("has_w4a8_mxfp4_fp8", &tc::QuantMode::hasW4a8Mxfp4Fp8) + + .def_prop_ro("has_w4a8_mxfp4_mxfp8", &tc::QuantMode::hasW4a8Mxfp4Mxfp8) + .def_prop_ro("has_w4a16_mxfp4", &tc::QuantMode::hasW4a16Mxfp4) + .def_prop_ro("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant) .def_static("from_description", &tc::QuantMode::fromDescription, nb::arg("quantize_weights"), nb::arg("quantize_activations"), nb::arg("per_token"), nb::arg("per_channel"), nb::arg("per_group"), nb::arg("use_int4_weights"), nb::arg("use_int8_kv_cache"), nb::arg("use_fp8_kv_kache"), nb::arg("use_fp8_qdq"), nb::arg("use_fp8_rowwise"), nb::arg("use_w4a8_qserve"), nb::arg("use_nvfp4"), - nb::arg("use_fp8_block_scales"), nb::arg("use_w4a8_mxfp4_fp8")) + nb::arg("use_fp8_block_scales"), nb::arg("use_w4a8_mxfp4_fp8"), nb::arg("use_w4a8_mxfp4_mxfp8"), + nb::arg("use_w4a16_mxfp4")) .def_static("use_smooth_quant", &tc::QuantMode::useSmoothQuant, nb::arg("per_token") = false, nb::arg("per_channel") = false) .def_static("use_weight_only", &tc::QuantMode::useWeightOnly, nb::arg("use_int4_weights") = false, diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp index 5760d77fb4..505ecfca59 100644 --- a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp @@ -110,11 +110,12 @@ void initConfigBindings(nb::module_& m) return nb::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(), self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), - self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm()); + self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), + self.getAttentionDpEventsGatherPeriodMs()); }; auto kvCacheConfigSetstate = [](tle::KvCacheConfig& self, nb::tuple const& state) { - if (state.size() != 13) + if (state.size() != 14) { throw std::runtime_error("Invalid state!"); } @@ -123,20 +124,21 @@ void initConfigBindings(nb::module_& m) nb::cast<std::optional<float>>(state[4]), nb::cast<std::optional<size_t>>(state[5]), nb::cast<bool>(state[6]), nb::cast<std::optional<float>>(state[7]), nb::cast<std::optional<tle::RetentionPriority>>(state[8]), nb::cast<size_t>(state[9]), - nb::cast<bool>(state[10]), nb::cast<bool>(state[11]), nb::cast<bool>(state[12])); + nb::cast<bool>(state[10]), nb::cast<bool>(state[11]), nb::cast<bool>(state[12]), + nb::cast<SizeType32>(state[13])); }; nb::class_<tle::KvCacheConfig>(m, "KvCacheConfig") .def(nb::init<bool, std::optional<SizeType32> const&, std::optional<std::vector<SizeType32>> const&, std::optional<SizeType32> const&, std::optional<float> const&, std::optional<size_t> const&, bool, std::optional<float> const&, std::optional<tle::RetentionPriority>, size_t const&, bool, bool, bool, - std::optional<RuntimeDefaults> const&>(), + SizeType32, std::optional<RuntimeDefaults> const&>(), nb::arg("enable_block_reuse") = true, nb::arg("max_tokens") = nb::none(), nb::arg("max_attention_window") = nb::none(), nb::arg("sink_token_length") = nb::none(), nb::arg("free_gpu_memory_fraction") = nb::none(), nb::arg("host_cache_size") = nb::none(), nb::arg("onboard_blocks") = true, nb::arg("cross_kv_cache_fraction") = nb::none(), nb::arg("secondary_offload_min_priority") = nb::none(), nb::arg("event_buffer_max_size") = 0, nb::kw_only(), nb::arg("enable_partial_reuse") = true, nb::arg("copy_on_partial_reuse") = true, nb::arg("use_uvm") = false, - nb::arg("runtime_defaults") = nb::none()) + nb::arg("attention_dp_events_gather_period_ms") = 5, nb::arg("runtime_defaults") = nb::none()) .def_prop_rw( "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse) .def_prop_rw("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens) @@ -159,6 +161,8 @@ void initConfigBindings(nb::module_& m) .def_prop_rw("copy_on_partial_reuse", &tle::KvCacheConfig::getCopyOnPartialReuse, &tle::KvCacheConfig::setCopyOnPartialReuse) .def_prop_rw("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm) + .def_prop_rw("attention_dp_events_gather_period_ms", &tle::KvCacheConfig::getAttentionDpEventsGatherPeriodMs, + &tle::KvCacheConfig::setAttentionDpEventsGatherPeriodMs) .def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults) .def("__getstate__", kvCacheConfigGetstate) .def("__setstate__", kvCacheConfigSetstate); diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index a3a8e087e3..a22a62bf80 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -220,10 +220,10 @@ void initBindings(nb::module_& m) nb::class_<tr::decoder::DecoderState>(m, "DecoderState") .def(nb::init<>()) - .def("setup", &tr::decoder::DecoderState::setup, nb::arg("max_batch_size"), nb::arg("max_beam_width"), + .def("setup", &tr::decoder::DecoderState::setup, nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("max_attention_window"), nb::arg("sink_token_length"), nb::arg("max_sequence_length"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")) - .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, nb::arg("max_batch_size"), + .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("max_attention_window"), nb::arg("buffer_manager")) .def("setup_speculative_decoding", &tr::decoder::DecoderState::setupSpeculativeDecoding, nb::arg("speculative_decoding_mode"), nb::arg("max_tokens_per_engine_step"), nb::arg("dtype"), @@ -277,7 +277,7 @@ void initBindings(nb::module_& m) nb::class_<tr::GptDecoderBatched>(m, "GptDecoderBatched") .def(nb::init<tr::GptDecoderBatched::CudaStreamPtr>(), nb::arg("stream")) - .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_batch_size"), + .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config")) .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("output"), nb::arg("input")) .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference) diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp new file mode 100644 index 0000000000..072df32b28 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "bindings.h" +#include <nanobind/nanobind.h> +#include <nanobind/stl/optional.h> +#include <nanobind/stl/vector.h> +#include <tensorrt_llm/thop/attentionOp.h> +#include <torch/extension.h> + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::thop +{ + +void initBindings(nb::module_& m) +{ + m.def("attention", &torch_ext::attention, + // Parameters with default values using std::nullopt for optional arguments + nb::arg("q"), nb::arg("k") = std::nullopt, nb::arg("v") = std::nullopt, nb::arg("output"), + nb::arg("output_sf") = std::nullopt, nb::arg("out_dtype") = std::nullopt, nb::arg("workspace_") = std::nullopt, + nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("context_lengths"), + nb::arg("host_context_lengths"), nb::arg("host_request_types"), + nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt, + nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt, + nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt, + nb::arg("kv_scale_quant_orig") = std::nullopt, nb::arg("out_scale") = std::nullopt, + nb::arg("rotary_inv_freq") = std::nullopt, nb::arg("rotary_cos_sin") = std::nullopt, + nb::arg("latent_cache") = std::nullopt, nb::arg("q_pe") = std::nullopt, + nb::arg("block_ids_per_seq") = std::nullopt, nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"), + nb::arg("update_kv_cache"), nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"), + nb::arg("num_kv_heads"), nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt, + nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"), + nb::arg("sink_token_length"), nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"), + nb::arg("q_scaling"), nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"), + nb::arg("rotary_embedding_base"), nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"), + nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"), + nb::arg("attention_input_type") = std::nullopt, nb::arg("is_mla_enable"), nb::arg("q_lora_rank") = std::nullopt, + nb::arg("kv_lora_rank") = std::nullopt, nb::arg("qk_nope_head_dim") = std::nullopt, + nb::arg("qk_rope_head_dim") = std::nullopt, nb::arg("v_head_dim") = std::nullopt, + nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt, + nb::arg("mla_context_paged_kv") = std::nullopt, nb::arg("mla_context_kv_cache_block_offsets") = std::nullopt, + nb::arg("attention_chunk_size") = std::nullopt, nb::arg("softmax_stats_tensor") = std::nullopt, + nb::arg("spec_decoding_bool_params"), nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation"); +} +} // namespace tensorrt_llm::nanobind::thop diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.h b/cpp/tensorrt_llm/nanobind/thop/bindings.h new file mode 100644 index 0000000000..534caf4c19 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include <nanobind/nanobind.h> + +namespace tensorrt_llm::nanobind::thop +{ + +void initBindings(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::thop diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp index e2fab9044c..1e6e75404d 100644 --- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp @@ -520,7 +520,7 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream); cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t), cudaMemcpyHostToDevice, stream); - mFMHARunner->run(fmhaParams); + mFmhaDispatcher->run(fmhaParams); if (iter != 0) { invokeRecoverFromRA<T>((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_, @@ -704,7 +704,18 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc } // Run the fmha kernel. - mFMHARunner->run(fmhaParams); + + // TODO: set it correctly for contiguous kv buffer (cross-attention). + fmhaParams.totalKvSeqLen = num_tokens; + + fmhaParams.cuKvSeqLenPtr = cu_seqlens; + fmhaParams.cuMaskRowsPtr = cu_seqlens; + fmhaParams.tileCounterPtr = fmha_tile_counter_ptr; + + fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr; + fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr; + fmhaParams.forceFp32Acc = mFMHAForceFP32Acc; + mFmhaDispatcher->run(fmhaParams); sync_check_cuda_error(stream); if (mSageAttn) { @@ -948,10 +959,14 @@ int BertAttentionPlugin::initialize() noexcept } // Load kernels from the pre-compiled cubins. - mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams)); + // The KV input data type. The default is same as dataType. + fmhaParams.dataTypeKv = data_type; + fmhaParams.headSizeV = mHeadSize; + // Load kernels from the pre-compiled cubins. + mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams)); // Fall back to unfused MHA kernels if not supported. - mEnableContextFMHA = mFMHARunner->isFmhaSupported(); + mEnableContextFMHA = mFmhaDispatcher->isSupported(); } #if ENABLE_MULTI_DEVICE diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h index 0c5fdc15b6..2eb39086a0 100644 --- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h +++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h @@ -18,7 +18,7 @@ #include "tensorrt_llm/common/cublasMMWrapper.h" #include "tensorrt_llm/common/quantization.h" -#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h" +#include "tensorrt_llm/kernels/fmhaDispatcher.h" #include "tensorrt_llm/kernels/gptKernels.h" #include "tensorrt_llm/plugins/common/plugin.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" @@ -114,7 +114,7 @@ private: cudaStream_t mNcclStream; // The default copy constructor will leave them as nullptr. clone() shall initialize it. - UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner; + UniqPtrWNullCopy<tensorrt_llm::kernels::FmhaDispatcher> mFmhaDispatcher; UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper; }; diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp index 6db0e4a382..189e23b8ac 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -334,12 +334,13 @@ void MixtureOfExpertsPlugin::init() static_cast<int>(mType), static_cast<int>(mWeightType), static_cast<int>(mOutputType)); } - mMOERunner->use_deterministic_hopper_reduce_ = mExpertsPerToken > 2 && mUseDeterministicKernels; + mMOERunner->use_fused_finalize_ + = (mExpertsPerToken < 3 || !mUseDeterministicKernels) && !getEnvMOEDisableFinalizeFusion(); mGemmId1 = GemmIDMoe{1, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize, - mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_}; + mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_}; mGemmId2 = GemmIDMoe{2, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize, - mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_}; + mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_}; mGemmProfiler->setMaxProfileM(16384 * mNumExperts / mExpertsPerToken); if (hasLora()) @@ -957,23 +958,25 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, MoeMinLatencyParams min_latency_params{}; mMOERunner->setTactic(gemm1, gemm2); #ifdef USING_OSS_CUTLASS_MOE_GEMM - mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, + mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, true, static_cast<int const*>(inputs[getTokenSelectedExpertsIndex()]), hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr, - inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType, - inputs[getExpertWeights2Index()], hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, - mExpertHiddenSize, mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace), + inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, + ActivationParams(mActivationType), inputs[getExpertWeights2Index()], + hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, mExpertHiddenSize, + mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace), // Outputs outputs[getOutputTensorIndex()], static_cast<int*>(workspace.src_to_dest_map), mParallelismConfig, /*enable_alltoall=*/false, hasLora(), lora_params, /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, min_latency_params, stream); #else - mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, + mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, true, static_cast<int const*>(inputs[getTokenSelectedExpertsIndex()]), hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr, - inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType, - inputs[getExpertWeights2Index()], hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, - mExpertHiddenSize, mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace), + inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, + ActivationParams(mActivationType), inputs[getExpertWeights2Index()], + hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, mExpertHiddenSize, + mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace), // Outputs outputs[getOutputTensorIndex()], static_cast<int*>(workspace.src_to_dest_map), mParallelismConfig, hasLora(), lora_params, /*use_deepseek_fp8_block_scale=*/false, diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h index d7b804fab0..cd3aaf52c2 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h @@ -44,6 +44,7 @@ using MoeMinLatencyParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MoeMinLatencyPar using MOEParallelismConfig = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MOEParallelismConfig; using QuantParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::QuantParams; using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; +using ActivationParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams; using TmaWarpSpecializedGroupedGemmInput = CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput; using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation; diff --git a/cpp/tensorrt_llm/plugins/quantizeToFP4Plugin/quantizeToFP4Plugin.cpp b/cpp/tensorrt_llm/plugins/quantizeToFP4Plugin/quantizeToFP4Plugin.cpp index b75e7cb066..b7454cce4f 100644 --- a/cpp/tensorrt_llm/plugins/quantizeToFP4Plugin/quantizeToFP4Plugin.cpp +++ b/cpp/tensorrt_llm/plugins/quantizeToFP4Plugin/quantizeToFP4Plugin.cpp @@ -160,7 +160,7 @@ int QuantizeToFP4Plugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, case DataType::kHALF: { auto input = reinterpret_cast<half const*>(inputs[0]); - invokeFP4Quantization(m, n, input, SFScale, output, SFoutput, false, FP4QuantizationSFLayout::SWIZZLED, + invokeFP4Quantization(1, m, n, input, SFScale, output, SFoutput, false, QuantizationSFLayout::SWIZZLED, mMultiProcessorCount, stream); break; } @@ -168,7 +168,7 @@ int QuantizeToFP4Plugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, case DataType::kBF16: { auto input = reinterpret_cast<__nv_bfloat16 const*>(inputs[0]); - invokeFP4Quantization(m, n, input, SFScale, output, SFoutput, false, FP4QuantizationSFLayout::SWIZZLED, + invokeFP4Quantization(1, m, n, input, SFScale, output, SFoutput, false, QuantizationSFLayout::SWIZZLED, mMultiProcessorCount, stream); break; } @@ -176,7 +176,7 @@ int QuantizeToFP4Plugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, case DataType::kFP8: { auto input = reinterpret_cast<__nv_fp8_e4m3 const*>(inputs[0]); - invokeFP4Quantization(m, n, input, SFScale, output, SFoutput, false, FP4QuantizationSFLayout::SWIZZLED, + invokeFP4Quantization(1, m, n, input, SFScale, output, SFoutput, false, QuantizationSFLayout::SWIZZLED, mMultiProcessorCount, stream); break; } diff --git a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp index 4235c808a4..b0f9739eaa 100644 --- a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp @@ -398,7 +398,7 @@ IPluginV2* SmoothQuantGemmPluginCreator::createPlugin(char const* name, PluginFi // Create plugin profiler with shared tactics map auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false); QuantMode quantMode = QuantMode::fromDescription(true, true, perTokenScaling, perChannelScaling, false, false, - false, false, false, false, false, false, false, false); + false, false, false, false, false, false, false, false, false, false); auto* obj = new SmoothQuantGemmPlugin(quantMode, type, pluginProfiler); obj->setPluginNamespace(mNamespace.c_str()); return obj; diff --git a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt index b4809d5135..91b5ebf548 100755 --- a/cpp/tensorrt_llm/pybind/CMakeLists.txt +++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt @@ -18,6 +18,7 @@ set(SRCS runtime/moeBindings.cpp userbuffers/bindings.cpp ../runtime/ipcNvlsMemory.cu + thop/bindings.cpp bindings.cpp) include_directories(${PROJECT_SOURCE_DIR}/include) @@ -43,7 +44,9 @@ target_link_libraries( ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python - ${CUDA_NVML_LIB}) + CUDA::cuda_driver + ${CUDA_NVML_LIB} + th_common) target_compile_definitions( ${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE} PYBIND11_DETAILED_ERROR_MESSAGES=1) @@ -53,6 +56,6 @@ if(NOT WIN32) ${TRTLLM_PYBIND_MODULE} PROPERTIES LINK_FLAGS - "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' -Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib/stubs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" + "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" ) endif() diff --git a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp index f098398b62..8f0cc3315c 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp @@ -105,23 +105,21 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod "__call__", [](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, - tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType, - DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, - tensorrt_llm::runtime::CudaStream const& runtimeStream, + nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, + runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const& runtimeStream, tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, SizeType32 beamWidth) { OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt; - auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig, - worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState, - runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers); + auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] + = self(modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers, + decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers); return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs), std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)}; }, py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"), - py::arg("buffer_manager"), py::arg("logits_type"), py::arg("decoder_input_buffers"), - py::arg("decoder_state"), py::arg("runtime_stream"), py::arg("decoder_stream"), - py::arg("max_sequence_length"), py::arg("beam_width")) + py::arg("logits_type"), py::arg("decoder_input_buffers"), py::arg("decoder_state"), + py::arg("runtime_stream"), py::arg("decoder_stream"), py::arg("max_sequence_length"), py::arg("beam_width")) .def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; }); } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index f0d74f4f99..5cf036e76c 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -253,7 +253,8 @@ void initBindings(pybind11::module_& m) } }) .def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest) - .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics); + .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics) + .def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel); py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr()) .def(py::init<>( @@ -381,7 +382,8 @@ void initBindings(pybind11::module_& m) .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager")) .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason")) .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) - .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter")); + .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter")) + .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors); py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager") .def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"), diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 255b0f8efa..54835e81d7 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -321,7 +321,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def_static("hash", &tbk::BlockKeyHasher::hash, py::arg("block_key"), py::arg("parent_hash") = 0); py::class_<tbk::KVCacheEventManager, std::shared_ptr<tbk::KVCacheEventManager>>(m, "KVCacheEventManager") - .def(py::init<size_t>(), py::arg("max_kv_event_entries")); + .def(py::init<size_t, std::optional<SizeType32>, std::optional<SizeType32>, SizeType32>(), + py::arg("max_kv_event_entries"), py::arg("attention_dp_rank") = std::nullopt, + py::arg("attention_dp_size") = std::nullopt, py::arg("attention_dp_events_gather_period_ms") = 5); py::classh<tbk::BaseKVCacheManager, PyKvCacheManager>(m, "BaseKVCacheManager") .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, py::arg("config"), diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 216baaa362..cdc9736db0 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -33,6 +33,7 @@ #include "tensorrt_llm/pybind/executor/bindings.h" #include "tensorrt_llm/pybind/runtime/bindings.h" #include "tensorrt_llm/pybind/testing/modelSpecBinding.h" +#include "tensorrt_llm/pybind/thop/bindings.h" #include "tensorrt_llm/pybind/userbuffers/bindings.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/cudaStream.h" @@ -116,9 +117,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings"); auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings"); auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings"); + auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings"); tensorrt_llm::pybind::executor::initBindings(mExecutor); tensorrt_llm::pybind::runtime::initBindingsEarly(mInternalRuntime); + tensorrt_llm::pybind::thop::initBindings(mInternalThop); auto buildInfo = m.def_submodule("BuildInfo"); buildInfo.attr("ENABLE_MULTI_DEVICE") = py::int_(ENABLE_MULTI_DEVICE); @@ -237,12 +240,15 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_property_readonly("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq) .def_property_readonly("has_nvfp4", &tc::QuantMode::hasNvfp4) .def_property_readonly("has_w4a8_mxfp4_fp8", &tc::QuantMode::hasW4a8Mxfp4Fp8) + .def_property_readonly("has_w4a8_mxfp4_mxfp8", &tc::QuantMode::hasW4a8Mxfp4Mxfp8) + .def_property_readonly("has_w4a16_mxfp4", &tc::QuantMode::hasW4a16Mxfp4) .def_property_readonly("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant) .def_static("from_description", &tc::QuantMode::fromDescription, py::arg("quantize_weights"), py::arg("quantize_activations"), py::arg("per_token"), py::arg("per_channel"), py::arg("per_group"), py::arg("use_int4_weights"), py::arg("use_int8_kv_cache"), py::arg("use_fp8_kv_kache"), py::arg("use_fp8_qdq"), py::arg("use_fp8_rowwise"), py::arg("use_w4a8_qserve"), py::arg("use_nvfp4"), - py::arg("use_fp8_block_scales"), py::arg("use_w4a8_mxfp4_fp8")) + py::arg("use_fp8_block_scales"), py::arg("use_w4a8_mxfp4_fp8"), py::arg("use_w4a8_mxfp4_mxfp8"), + py::arg("use_w4a16_mxfp4")) .def_static("use_smooth_quant", &tc::QuantMode::useSmoothQuant, py::arg("per_token") = false, py::arg("per_channel") = false) .def_static("use_weight_only", &tc::QuantMode::useWeightOnly, py::arg("use_int4_weights") = false, diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp index a8f6aaef73..bbb843bedb 100644 --- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -240,7 +240,8 @@ void initBindings(pybind11::module_& m) py::class_<tle::KVCacheEvent>(executor_kv_cache, "KVCacheEvent") .def_readonly("event_id", &tle::KVCacheEvent::eventId) .def_readonly("data", &tle::KVCacheEvent::data) - .def_readonly("window_size", &tle::KVCacheEvent::windowSize); + .def_readonly("window_size", &tle::KVCacheEvent::windowSize) + .def_readonly("attention_dp_rank", &tle::KVCacheEvent::attentionDpRank); py::class_<tle::KVCacheEventManager, std::shared_ptr<tle::KVCacheEventManager>>( executor_kv_cache, "KVCacheEventManager") diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp index ccbb21aab2..0e279a3e47 100644 --- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp @@ -103,11 +103,12 @@ void initConfigBindings(pybind11::module_& m) return py::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(), self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), - self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm()); + self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), + self.getAttentionDpEventsGatherPeriodMs()); }; auto kvCacheConfigSetstate = [](py::tuple const& state) { - if (state.size() != 13) + if (state.size() != 14) { throw std::runtime_error("Invalid state!"); } @@ -115,20 +116,21 @@ void initConfigBindings(pybind11::module_& m) state[2].cast<std::optional<std::vector<SizeType32>>>(), state[3].cast<std::optional<SizeType32>>(), state[4].cast<std::optional<float>>(), state[5].cast<std::optional<size_t>>(), state[6].cast<bool>(), state[7].cast<std::optional<float>>(), state[8].cast<std::optional<tle::RetentionPriority>>(), - state[9].cast<size_t>(), state[10].cast<bool>(), state[11].cast<bool>(), state[12].cast<bool>()); + state[9].cast<size_t>(), state[10].cast<bool>(), state[11].cast<bool>(), state[12].cast<bool>(), + state[13].cast<SizeType32>()); }; py::class_<tle::KvCacheConfig>(m, "KvCacheConfig") .def(py::init<bool, std::optional<SizeType32> const&, std::optional<std::vector<SizeType32>> const&, std::optional<SizeType32> const&, std::optional<float> const&, std::optional<size_t> const&, bool, std::optional<float> const&, std::optional<tle::RetentionPriority>, size_t const&, bool, bool, bool, - std::optional<RuntimeDefaults> const&>(), + SizeType32, std::optional<RuntimeDefaults> const&>(), py::arg("enable_block_reuse") = true, py::arg("max_tokens") = py::none(), py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none(), py::arg("free_gpu_memory_fraction") = py::none(), py::arg("host_cache_size") = py::none(), py::arg("onboard_blocks") = true, py::arg("cross_kv_cache_fraction") = py::none(), py::arg("secondary_offload_min_priority") = py::none(), py::arg("event_buffer_max_size") = 0, py::kw_only(), py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, py::arg("use_uvm") = false, - py::arg("runtime_defaults") = py::none()) + py::arg("attention_dp_events_gather_period_ms") = 5, py::arg("runtime_defaults") = py::none()) .def_property( "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse) .def_property("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens) @@ -151,6 +153,8 @@ void initConfigBindings(pybind11::module_& m) .def_property("copy_on_partial_reuse", &tle::KvCacheConfig::getCopyOnPartialReuse, &tle::KvCacheConfig::setCopyOnPartialReuse) .def_property("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm) + .def_property("attention_dp_events_gather_period_ms", &tle::KvCacheConfig::getAttentionDpEventsGatherPeriodMs, + &tle::KvCacheConfig::setAttentionDpEventsGatherPeriodMs) .def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults) .def(py::pickle(kvCacheConfigGetstate, kvCacheConfigSetstate)); diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index 432f7e6b13..17aa48ef30 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -312,10 +312,10 @@ void initBindings(pybind11::module_& m) py::class_<tr::decoder::DecoderState>(m, "DecoderState") .def(py::init<>()) - .def("setup", &tr::decoder::DecoderState::setup, py::arg("max_batch_size"), py::arg("max_beam_width"), + .def("setup", &tr::decoder::DecoderState::setup, py::arg("max_num_sequences"), py::arg("max_beam_width"), py::arg("max_attention_window"), py::arg("sink_token_length"), py::arg("max_sequence_length"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager")) - .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, py::arg("max_batch_size"), + .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, py::arg("max_num_sequences"), py::arg("max_beam_width"), py::arg("max_attention_window"), py::arg("buffer_manager")) .def("setup_speculative_decoding", &tr::decoder::DecoderState::setupSpeculativeDecoding, py::arg("speculative_decoding_mode"), py::arg("max_tokens_per_engine_step"), py::arg("dtype"), @@ -371,7 +371,7 @@ void initBindings(pybind11::module_& m) py::class_<tr::GptDecoderBatched>(m, "GptDecoderBatched") .def(py::init<tr::GptDecoderBatched::CudaStreamPtr>(), py::arg("stream")) - .def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_batch_size"), + .def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_num_sequences"), py::arg("max_beam_width"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config")) .def("forward_async", &tr::GptDecoderBatched::forwardAsync, py::arg("decoder_state"), py::arg("input")) .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, py::return_value_policy::reference) @@ -456,7 +456,8 @@ void initBindings(pybind11::module_& m) .value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO) .value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB) .value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT) - .value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT); + .value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT) + .value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC); // Initialize MoeLoadBalancer bindings initMoeBindings(m); diff --git a/cpp/tensorrt_llm/pybind/thop/bindings.cpp b/cpp/tensorrt_llm/pybind/thop/bindings.cpp new file mode 100644 index 0000000000..6575043625 --- /dev/null +++ b/cpp/tensorrt_llm/pybind/thop/bindings.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "bindings.h" +#include <pybind11/functional.h> +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <tensorrt_llm/thop/attentionOp.h> +#include <torch/extension.h> + +namespace py = pybind11; + +namespace tensorrt_llm::pybind::thop +{ + +void initBindings(pybind11::module_& m) +{ + m.def("attention", &torch_ext::attention, + // Parameters with default values using std::nullopt for optional arguments + py::arg("q"), py::arg("k") = std::nullopt, py::arg("v") = std::nullopt, py::arg("output"), + py::arg("output_sf") = std::nullopt, py::arg("out_dtype") = std::nullopt, py::arg("workspace_") = std::nullopt, + py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("context_lengths"), + py::arg("host_context_lengths"), py::arg("host_request_types"), + py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt, + py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt, + py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt, + py::arg("kv_scale_quant_orig") = std::nullopt, py::arg("out_scale") = std::nullopt, + py::arg("rotary_inv_freq") = std::nullopt, py::arg("rotary_cos_sin") = std::nullopt, + py::arg("latent_cache") = std::nullopt, py::arg("q_pe") = std::nullopt, + py::arg("block_ids_per_seq") = std::nullopt, py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"), + py::arg("update_kv_cache"), py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"), + py::arg("num_kv_heads"), py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt, + py::arg("max_num_requests"), py::arg("max_context_length"), py::arg("attention_window_size"), + py::arg("sink_token_length"), py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"), + py::arg("q_scaling"), py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"), + py::arg("rotary_embedding_base"), py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"), + py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"), + py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"), py::arg("q_lora_rank") = std::nullopt, + py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt, + py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt, + py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt, + py::arg("mla_context_paged_kv") = std::nullopt, py::arg("mla_context_kv_cache_block_offsets") = std::nullopt, + py::arg("attention_chunk_size") = std::nullopt, py::arg("softmax_stats_tensor") = std::nullopt, + py::arg("spec_decoding_bool_params"), py::arg("spec_decoding_tensor_params"), "Multi-head attention operation"); +} +} // namespace tensorrt_llm::pybind::thop diff --git a/cpp/tensorrt_llm/pybind/thop/bindings.h b/cpp/tensorrt_llm/pybind/thop/bindings.h new file mode 100644 index 0000000000..08d429b850 --- /dev/null +++ b/cpp/tensorrt_llm/pybind/thop/bindings.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/pybind/common/customCasters.h" +#include <pybind11/pybind11.h> + +namespace tensorrt_llm::pybind::thop +{ + +void initBindings(pybind11::module_& m); + +} // namespace tensorrt_llm::pybind::thop diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp index 740aa6e9cf..abccbe60a1 100644 --- a/cpp/tensorrt_llm/runtime/decoderState.cpp +++ b/cpp/tensorrt_llm/runtime/decoderState.cpp @@ -53,13 +53,13 @@ DecoderState::DecoderState() TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void DecoderState::setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, +void DecoderState::setup(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); setupBuffers(dtype, bufferManager); - reshapeBuffers(maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength, modelConfig, + reshapeBuffers(maxNumSequences, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength, modelConfig, worldConfig, bufferManager); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -90,7 +90,6 @@ void DecoderState::setupBuffers(nvinfer1::DataType dtype, BufferManager const& b dOutput->lengths = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); dOutput->finishedSum = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); - // we don't need dOutput->lengths because lengths are passed from outside dOutput->cumLogProbs = bufferManager.emptyTensor(MemoryType::kGPU, nvFloatType); dOutput->logProbs = bufferManager.emptyTensor(MemoryType::kGPU, nvFloatType); dOutput->beamHypotheses.empty(bufferManager); @@ -197,7 +196,7 @@ void DecoderState::setupSpeculativeDecodingBuffers( TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void DecoderState::reshapeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, +void DecoderState::reshapeBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager) { @@ -205,75 +204,76 @@ void DecoderState::reshapeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWid auto const& stream = bufferManager.getStream(); - TLLM_CHECK(maxBatchSize > 0); + TLLM_CHECK(maxNumSequences > 0); TLLM_CHECK(maxBeamWidth > 0); TLLM_CHECK(mMaxDecodingEngineTokens > 0); TLLM_CHECK(maxSequenceLength > 0); - mMaxBatchSize = maxBatchSize; + mMaxNumSequences = maxNumSequences; mMaxBeamWidth = maxBeamWidth; mMaxSequenceLength = maxSequenceLength; mNumDecodingEngineTokens.clear(); - mNumDecodingEngineTokens.resize(mMaxBatchSize, 0); + mNumDecodingEngineTokens.resize(mMaxNumSequences, 0); // setup input auto& dInput = *mJointDecodingInput; dInput.maxLength = mMaxSequenceLength; dInput.maxAttentionWindow = maxAttentionWindow; dInput.sinkTokenLength = sinkTokenLength; - dInput.stopWordsLists.resize(mMaxBatchSize); - dInput.badWordsLists.resize(mMaxBatchSize); + dInput.stopWordsLists.resize(mMaxNumSequences); + dInput.badWordsLists.resize(mMaxNumSequences); - auto const maxBatchSizeShape = ITensor::makeShape({mMaxBatchSize}); - auto const maxBatchSizeXmaxBeamWidthShape = ITensor::makeShape({mMaxBatchSize, mMaxBeamWidth}); + auto const maxNumSequencesShape = ITensor::makeShape({mMaxNumSequences}); + auto const maxNumSequencesXmaxBeamWidthShape = ITensor::makeShape({mMaxNumSequences, mMaxBeamWidth}); - const_cast<ITensor&>(*dInput.endIds).reshape(maxBatchSizeShape); + const_cast<ITensor&>(*dInput.endIds).reshape(maxNumSequencesShape); auto& sequenceLimitLength = const_cast<ITensor&>(*dInput.sequenceLimitLength); - sequenceLimitLength.reshape(maxBatchSizeShape); + sequenceLimitLength.reshape(maxNumSequencesShape); kernels::invokeFill(sequenceLimitLength, mMaxSequenceLength, stream); auto& inputLengths = const_cast<ITensor&>(*dInput.lengths); - inputLengths.reshape(maxBatchSizeXmaxBeamWidthShape); + inputLengths.reshape(maxNumSequencesXmaxBeamWidthShape); bufferManager.setZero(inputLengths); dInput.beamWidths.clear(); - dInput.beamWidths.resize(mMaxBatchSize, 0); + dInput.beamWidths.resize(mMaxNumSequences, 0); - auto const maxTotalTokensShape = ITensor::makeShape({mMaxBatchSize, mMaxBeamWidth, mMaxSequenceLength}); + auto const maxTotalTokensShape = ITensor::makeShape({mMaxNumSequences, mMaxBeamWidth, mMaxSequenceLength}); // setup output auto& dOutput = *mJointDecodingOutput; dOutput.ids->reshape(maxTotalTokensShape); - dOutput.finishReasons->reshape(maxBatchSizeXmaxBeamWidthShape); + auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxNumSequences, mMaxBeamWidth}); + + dOutput.finishReasons->reshape(maxNumSequencesXmaxBeamWidthShape); bufferManager.setZero(*dOutput.finishReasons); dOutput.parentIds->reshape(maxTotalTokensShape); - dOutput.lengths->reshape(maxBatchSizeXmaxBeamWidthShape); + dOutput.lengths->reshape(maxNumSequencesXmaxBeamWidthShape); bufferManager.setZero(*dOutput.lengths); - dOutput.finishedSum->reshape(maxBatchSizeShape); + dOutput.finishedSum->reshape(maxNumSequencesShape); bufferManager.setZero(*dOutput.finishedSum); - auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxBatchSize, mMaxBeamWidth}); dOutput.newTokensSteps->reshape(maxNewTokensShape); bufferManager.setZero(*dOutput.newTokensSteps); - dOutput.cumLogProbs->reshape(maxBatchSizeXmaxBeamWidthShape); + dOutput.cumLogProbs->reshape(maxNumSequencesXmaxBeamWidthShape); bufferManager.setZero(*dOutput.cumLogProbs); dOutput.logProbs->reshape(maxTotalTokensShape); bufferManager.setZero(*dOutput.logProbs); - dOutput.logProbsTiled->reshape(ITensor::makeShape({mMaxSequenceLength, mMaxBatchSize, mMaxBeamWidth})); + dOutput.logProbsTiled->reshape(ITensor::makeShape({mMaxSequenceLength, mMaxNumSequences, mMaxBeamWidth})); bufferManager.setZero(*dOutput.logProbsTiled); if (mMaxBeamWidth > 1) { - dOutput.beamHypotheses.reshape(mMaxBatchSize, mMaxBeamWidth, mMaxSequenceLength); + dOutput.beamHypotheses.reshape(mMaxNumSequences, mMaxBeamWidth, mMaxSequenceLength); mBeamSearchBuffers->reshape(mMaxBeamWidth, mMaxSequenceLength); - reshapeCacheIndirectionBuffers(mMaxBatchSize, mMaxBeamWidth, maxAttentionWindow); + reshapeCacheIndirectionBuffers(mMaxNumSequences, mMaxBeamWidth, maxAttentionWindow); dOutput.gatheredIds->reshape(maxTotalTokensShape); } @@ -285,21 +285,21 @@ void DecoderState::reshapeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWid auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize()); const_cast<ITensor&>(*dInput.embeddingBias) - .reshape(ITensor::makeShape({mMaxBatchSize, static_cast<SizeType32>(vocabSizePadded)})); - const_cast<ITensor&>(*dInput.badWordsPtrs).reshape(maxBatchSizeShape); - const_cast<ITensor&>(*dInput.badWordsLens).reshape(maxBatchSizeShape); - const_cast<ITensor&>(*dInput.stopWordsPtrs).reshape(maxBatchSizeShape); - const_cast<ITensor&>(*dInput.stopWordsLens).reshape(maxBatchSizeShape); + .reshape(ITensor::makeShape({mMaxNumSequences, static_cast<SizeType32>(vocabSizePadded)})); + const_cast<ITensor&>(*dInput.badWordsPtrs).reshape(maxNumSequencesShape); + const_cast<ITensor&>(*dInput.badWordsLens).reshape(maxNumSequencesShape); + const_cast<ITensor&>(*dInput.stopWordsPtrs).reshape(maxNumSequencesShape); + const_cast<ITensor&>(*dInput.stopWordsLens).reshape(maxNumSequencesShape); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void DecoderState::setupCacheIndirection( - SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, BufferManager const& bufferManager) +void DecoderState::setupCacheIndirection(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, + SizeType32 maxAttentionWindow, BufferManager const& bufferManager) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); setupCacheIndirectionBuffers(bufferManager); - reshapeCacheIndirectionBuffers(maxBatchSize, maxBeamWidth, maxAttentionWindow); + reshapeCacheIndirectionBuffers(maxNumSequences, maxBeamWidth, maxAttentionWindow); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -311,12 +311,12 @@ void DecoderState::setupCacheIndirectionBuffers(BufferManager const& bufferManag } void DecoderState::reshapeCacheIndirectionBuffers( - SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow) + SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow) { mJointDecodingInput->cacheIndirection->reshape( - ITensor::makeShape({maxBatchSize, maxBeamWidth, maxAttentionWindow})); + ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow})); mJointDecodingOutput->cacheIndirection->reshape( - ITensor::makeShape({maxBatchSize, maxBeamWidth, maxAttentionWindow})); + ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow})); } void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode const& speculativeDecodingMode, @@ -337,7 +337,7 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con "or > 1 for any speculative decoding mode.", mMaxDecodingEngineTokens); - auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxBatchSize, mMaxBeamWidth}); + auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxNumSequences, mMaxBeamWidth}); dOutput.newTokensSteps->reshape(maxNewTokensShape); bufferManager.setZero(*dOutput.newTokensSteps); @@ -356,42 +356,42 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con return; } - auto const maxBatchSizeShape = ITensor::makeShape({mMaxBatchSize}); + auto const maxNumSequencesShape = ITensor::makeShape({mMaxNumSequences}); if (speculativeDecodingMode.isDraftTokensExternal()) { auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize()); auto const probsShape = ITensor::makeShape( - {mMaxBatchSize, mMaxDecodingEngineTokens, mMaxBeamWidth, static_cast<SizeType32>(vocabSizePadded)}); + {mMaxNumSequences, mMaxDecodingEngineTokens, mMaxBeamWidth, static_cast<SizeType32>(vocabSizePadded)}); dInput.externalDraftTokensInputs->draftProbs->reshape(probsShape); dInput.externalDraftTokensInputs->targetProbs->reshape(probsShape); dInput.externalDraftTokensInputs->draftLogits->reshape( - ITensor::makeShape({mMaxBatchSize, mMaxDecodingEngineTokens, static_cast<SizeType32>(vocabSizePadded)})); + ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens, static_cast<SizeType32>(vocabSizePadded)})); dInput.externalDraftTokensInputs->draftTokenIds->reshape( - ITensor::makeShape({mMaxBatchSize, mMaxDecodingEngineTokens})); - dInput.externalDraftTokensInputs->numDraftTokens->reshape(maxBatchSizeShape); - dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(maxBatchSizeShape); - dInput.externalDraftTokensInputs->useDraftLogits->reshape(maxBatchSizeShape); - dInput.externalDraftTokensInputs->useDraftLogitsHost->reshape(maxBatchSizeShape); + ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens})); + dInput.externalDraftTokensInputs->numDraftTokens->reshape(maxNumSequencesShape); + dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(maxNumSequencesShape); + dInput.externalDraftTokensInputs->useDraftLogits->reshape(maxNumSequencesShape); + dInput.externalDraftTokensInputs->useDraftLogitsHost->reshape(maxNumSequencesShape); } if (speculativeDecodingMode.isMedusa()) { auto const speculativeDecodingModule = modelConfig.getSpeculativeDecodingModulePtr(); auto& medusaPaths = const_cast<ITensor&>(*dInput.medusaInputs->medusaPaths); - medusaPaths.reshape(ITensor::makeShape({mMaxBatchSize, speculativeDecodingModule->getMaxDecodingTokens(), + medusaPaths.reshape(ITensor::makeShape({mMaxNumSequences, speculativeDecodingModule->getMaxDecodingTokens(), speculativeDecodingModule->getMaxPathLen()})); bufferManager.setMem(medusaPaths, -1); auto& medusaTreeIds = const_cast<ITensor&>(*dInput.medusaInputs->medusaTreeIds); medusaTreeIds.reshape( - ITensor::makeShape({mMaxBatchSize, speculativeDecodingModule->getMaxDecodingDraftTokens()})); + ITensor::makeShape({mMaxNumSequences, speculativeDecodingModule->getMaxDecodingDraftTokens()})); bufferManager.setZero(medusaTreeIds); auto& curTokensPerStep = const_cast<ITensor&>(*dInput.medusaInputs->medusaCurTokensPerStep); auto& targetTokensPerStep = const_cast<ITensor&>(*dInput.medusaInputs->medusaTargetTokensPerStep); - curTokensPerStep.reshape(maxBatchSizeShape); - targetTokensPerStep.reshape(maxBatchSizeShape); + curTokensPerStep.reshape(maxNumSequencesShape); + targetTokensPerStep.reshape(maxNumSequencesShape); bufferManager.setZero(curTokensPerStep); bufferManager.setZero(targetTokensPerStep); } @@ -399,37 +399,37 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con if (speculativeDecodingMode.predictsDraftTokens()) { dOutput.speculativeDecodingOutputs->nextDraftTokens->reshape( - ITensor::makeShape({mMaxBatchSize, mMaxDecodingEngineTokens - 1})); + ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens - 1})); if (speculativeDecodingMode.variableDraftLength()) { - dOutput.speculativeDecodingOutputs->nextDraftTokensLen->reshape(maxBatchSizeShape); - dOutput.speculativeDecodingOutputs->prevDraftTokensLen->reshape(maxBatchSizeShape); + dOutput.speculativeDecodingOutputs->nextDraftTokensLen->reshape(maxNumSequencesShape); + dOutput.speculativeDecodingOutputs->prevDraftTokensLen->reshape(maxNumSequencesShape); } } if (speculativeDecodingMode.needsKVCacheRewind()) { auto const speculativeDecodingModule = modelConfig.getSpeculativeDecodingModulePtr(); - dOutput.speculativeDecodingOutputs->acceptedTokensLen->reshape(maxBatchSizeShape); - dOutput.speculativeDecodingOutputs->acceptedLengthsCumSum->reshape(ITensor::makeShape({mMaxBatchSize + 1})); + dOutput.speculativeDecodingOutputs->acceptedTokensLen->reshape(maxNumSequencesShape); + dOutput.speculativeDecodingOutputs->acceptedLengthsCumSum->reshape(ITensor::makeShape({mMaxNumSequences + 1})); dOutput.speculativeDecodingOutputs->pathsOffsets->reshape( - ITensor::makeShape({mMaxBatchSize * speculativeDecodingModule->getMaxDraftPathLen()})); + ITensor::makeShape({mMaxNumSequences * speculativeDecodingModule->getMaxDraftPathLen()})); } if (speculativeDecodingMode.isExplicitDraftTokens()) { mJointDecodingOutput->explicitDraftTokensBuffers = runtime::ExplicitDraftTokensBuffers::Inputs(); mJointDecodingOutput->explicitDraftTokensBuffers->create( - mMaxBatchSize, bufferManager, modelConfig, worldConfig); + mMaxNumSequences, bufferManager, modelConfig, worldConfig); } else if (speculativeDecodingMode.isEagle()) { mJointDecodingOutput->eagleBuffers = runtime::EagleBuffers::Inputs(); - mJointDecodingOutput->eagleBuffers->create(mMaxBatchSize, bufferManager, modelConfig, worldConfig); + mJointDecodingOutput->eagleBuffers->create(mMaxNumSequences, bufferManager, modelConfig, worldConfig); } else if (speculativeDecodingMode.isLookaheadDecoding()) { mJointDecodingOutput->lookaheadOutputs - = runtime::LookaheadDecodingBuffers(mMaxBatchSize, mMaxDecodingEngineTokens, bufferManager); + = runtime::LookaheadDecodingBuffers(mMaxNumSequences, mMaxDecodingEngineTokens, bufferManager); mJointDecodingInput->lookaheadInputs->tokensPerStep = mJointDecodingOutput->lookaheadOutputs->generationLengths; } @@ -446,7 +446,7 @@ void DecoderState::disableLookahead(RequestVector const& genRequests) mMaxDecodingDecoderTokens = 1; mJointDecodingInput->lookaheadInputs.reset(); - auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxBatchSize, mMaxBeamWidth}); + auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxNumSequences, mMaxBeamWidth}); mJointDecodingOutput->newTokensSteps->reshape(maxNewTokensShape); for (auto const& llmReq : genRequests) @@ -555,9 +555,9 @@ TensorPtr DecoderState::getAcceptedPackedPaths() const return mJointDecodingOutput->speculativeDecodingOutputs->pathsOffsets; } -SizeType32 DecoderState::getMaxBatchSize() const +SizeType32 DecoderState::getMaxNumSequences() const { - return mMaxBatchSize; + return mMaxNumSequences; } SizeType32 DecoderState::getMaxBeamWidth() const @@ -607,13 +607,15 @@ std::vector<SizeType32> const& DecoderState::getNumDecodingEngineTokens() const SizeType32 DecoderState::getNumDecodingEngineTokens(SizeType32 batchIdx) const { - TLLM_CHECK_WITH_INFO(batchIdx < mMaxBatchSize, "Batch index %d out of bounds (max %d)", batchIdx, mMaxBatchSize); + TLLM_CHECK_WITH_INFO( + batchIdx < mMaxNumSequences, "Batch index %d out of bounds (max %d)", batchIdx, mMaxNumSequences); return mNumDecodingEngineTokens[batchIdx]; } void DecoderState::setNumDecodingEngineTokens(SizeType32 batchIdx, SizeType32 numTokens) { - TLLM_CHECK_WITH_INFO(batchIdx < mMaxBatchSize, "Batch index %d out of bounds (max %d)", batchIdx, mMaxBatchSize); + TLLM_CHECK_WITH_INFO( + batchIdx < mMaxNumSequences, "Batch index %d out of bounds (max %d)", batchIdx, mMaxNumSequences); mNumDecodingEngineTokens[batchIdx] = numTokens; } @@ -642,6 +644,11 @@ void DecoderState::setGenerationSteps(std::vector<SizeType32> const& generationS mJointDecodingInput->generationSteps = generationSteps; } +void DecoderState::setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth) +{ + mJointDecodingInput->beamWidths.at(batchIdx) = beamWidth; +} + DecodingInput& DecoderState::getJointDecodingInput() const { return *mJointDecodingInput; diff --git a/cpp/tensorrt_llm/runtime/gptDecoder.cpp b/cpp/tensorrt_llm/runtime/gptDecoder.cpp index a9805881e7..610eae1138 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoder.cpp @@ -36,11 +36,11 @@ using TensorConstPtr = ITensor::SharedConstPtr; using TensorPtr = ITensor::SharedPtr; template <typename T> -GptDecoder<T>::GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, +GptDecoder<T>::GptDecoder(executor::DecodingMode const& mode, size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, CudaStreamPtr const& stream, std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule) : mManager{std::make_shared<BufferManager>(stream)} - , mMaxBatchSize(maxBatchSize) + , mMaxNumSequences(maxNumSequences) , mVocabSize(vocabSize) , mVocabSizePadded(vocabSizePadded) , mDecodingMode{mode} @@ -48,7 +48,7 @@ GptDecoder<T>::GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSiz TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const decodingDomain = tensorrt_llm::layers::DecoderDomain( - maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, speculativeDecodingModule); + maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, speculativeDecodingModule); mDynamicDecodeLayer = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(mode, decodingDomain, mManager); mDecodingLayerWorkspace = std::make_unique<tensorrt_llm::runtime::DecodingLayerWorkspace>( @@ -65,7 +65,7 @@ void GptDecoder<T>::disableLookahead( mDecodingMode = executor::DecodingMode::TopKTopP(); auto const decodingDomain - = tensorrt_llm::layers::DecoderDomain(mMaxBatchSize, 1, mVocabSize, mVocabSizePadded, nullptr); + = tensorrt_llm::layers::DecoderDomain(mMaxNumSequences, 1, mVocabSize, mVocabSizePadded, nullptr); auto setupParams = std::make_shared<layers::DynamicDecodeSetupParams>(); @@ -286,7 +286,7 @@ std::shared_ptr<tl::StopCriteriaDecodingInputs> prepareStopCriteriaInputs(Decodi } void prepareMedusaInputs( - DecodingInput const& inputs, size_t maxBatchSize, std::shared_ptr<tl::DecodingInputs>& baseInputs) + DecodingInput const& inputs, size_t maxNumSequences, std::shared_ptr<tl::DecodingInputs>& baseInputs) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -303,7 +303,7 @@ void prepareMedusaInputs( { std::vector<std::vector<TensorPtr>> medusaLogits; auto const batchSize = medusaInputs.medusaLogits.size(); - medusaLogits.resize(maxBatchSize); + medusaLogits.resize(maxNumSequences); for (size_t bi = 0; bi < batchSize; ++bi) { auto const slot = batchSlots[bi]; @@ -412,7 +412,7 @@ void prepareEagleInput(DecodingInput const& inputs, std::shared_ptr<tl::Decoding template <typename T> std::shared_ptr<tl::BaseDecodingInputs> prepareInputs( - DecodingInput const& input, size_t maxBatchSize, tle::DecodingMode const& decodingMode) + DecodingInput const& input, size_t maxNumSequences, tle::DecodingMode const& decodingMode) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -499,7 +499,7 @@ std::shared_ptr<tl::BaseDecodingInputs> prepareInputs( // Speculative decoding if (decodingMode.isMedusa()) { - prepareMedusaInputs(input, maxBatchSize, forwardParams); + prepareMedusaInputs(input, maxNumSequences, forwardParams); } else if (decodingMode.isExplicitDraftTokens()) { @@ -739,7 +739,7 @@ void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& in { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto forwardParams = prepareInputs<T>(input, mMaxBatchSize, mDecodingMode); + auto forwardParams = prepareInputs<T>(input, mMaxNumSequences, mDecodingMode); auto outputParams = prepareOutputs(output, mDecodingMode); mDynamicDecodeLayer->forwardAsync(outputParams, forwardParams, mDecodingLayerWorkspace); @@ -750,7 +750,7 @@ template <typename T> void GptDecoder<T>::forwardSync(DecodingOutput& output, DecodingInput const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto forwardParams = prepareInputs<T>(input, mMaxBatchSize, mDecodingMode); + auto forwardParams = prepareInputs<T>(input, mMaxNumSequences, mDecodingMode); auto outputParams = prepareOutputs(output, mDecodingMode); mDynamicDecodeLayer->forwardSync(outputParams, forwardParams, mDecodingLayerWorkspace); diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index fa4cba2d1e..6df7b1634b 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -72,11 +72,11 @@ void GptDecoderBatched::disableLookahead(RequestVector const& genRequests, Tenso TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth, +void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - TLLM_CHECK(maxBatchSize > 0); + TLLM_CHECK(maxNumSequences > 0); TLLM_CHECK(maxBeamWidth > 0); std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModulePtr = nullptr; @@ -92,8 +92,8 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max auto const vocabSize = modelConfig.getVocabSize(); auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize()); - mDecoder = IGptDecoder::create(mode, dtype, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, mDecoderStream, - speculativeDecodingModulePtr); + mDecoder = IGptDecoder::create(mode, dtype, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, + mDecoderStream, speculativeDecodingModulePtr); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index a9d0d4009f..494788c228 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -71,6 +71,8 @@ add_library( moeUtilOp.cpp moeCommOp.cpp moeLoadBalanceOp.cpp + mxFp4BlockScaleMoe.cpp + mxFp8Quantize.cpp fp8BlockScaleMoe.cpp fp8PerTensorScaleMoe.cpp fp4BlockScaleMoe.cpp diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index b38aea3ecd..7f719524f9 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -163,9 +163,9 @@ public: { size_t size = input.numel(); size_t seq_len = input.size(0); + size_t bytes_per_element = input.element_size(); + TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element); - // If strategy is set to UB, UB must be used as UB impl output is special and cannot be used - // by others. AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size); // Log runtime strategy @@ -177,6 +177,8 @@ public: { case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias); case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias); + case AllReduceStrategyType::NCCL_SYMMETRIC: + return runNCCLAllReduceSymmetric(input, residual, norm_weight, scale, bias); case AllReduceStrategyType::MIN_LATENCY: case AllReduceStrategyType::ONESHOT: case AllReduceStrategyType::TWOSHOT: @@ -303,6 +305,39 @@ private: return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output); } + std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input, + torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight, + torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept + { + + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + int size = input.numel(); + auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance(); + auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr()); + if (ub_buffer0.invalid()) + { + auto [symmetric_input, symmetric_ub_buffer0] + = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); + cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(), + cudaMemcpyDeviceToDevice, stream); + ub_buffer0 = symmetric_ub_buffer0; + } + + TLLM_CHECK(!ub_buffer0.invalid()); + auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); + + NCCLCHECK(ncclAllReduce( + ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream)); + + if (mOp == AllReduceFusionOp::NONE) + { + return {norm_out}; + } + + // Treat any other patterns as fallback cases. + return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, norm_out); + } + std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept @@ -487,7 +522,7 @@ private: output_shape[r - 1] = k / 2; quant_out = at::detail::empty_cuda(output_shape, FLOAT4_E2M1X2, input.device(), std::nullopt); - scale_out = at::detail::empty_cuda({tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sf_vec_size)}, + scale_out = at::detail::empty_cuda({tensorrt_llm::computeSwizzledLayoutSFSize(m, k / sf_vec_size)}, SF_DTYPE, input.device(), std::nullopt); residual_out = torch::empty_like(residual.value()); @@ -633,6 +668,10 @@ private: { runtime_strategy = AllReduceStrategyType::NCCL; } + else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC) + { + runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC; + } else { // This is for DEBUG and BENCHMARK purpose. It will overried the strategy if AUTO is set. @@ -658,6 +697,11 @@ private: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank); break; } + case AllReduceStrategyType::NCCL_SYMMETRIC: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank); + break; + } case AllReduceStrategyType::MIN_LATENCY: { TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank); @@ -673,7 +717,7 @@ private: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank); break; } - default: break; + default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break; } } @@ -1108,7 +1152,7 @@ std::vector<torch::Tensor> moe_finalize_allreduce(torch::Tensor const& input, to } at::Tensor mnnvlTwoShotAllReduce( - at::Tensor& input, at::Tensor& comm_buffer, at::Tensor& buffer_flags, bool wait_for_results) + at::Tensor& input, at::Tensor& comm_buffer, at::Tensor& buffer_flags, int64_t buffer_size, bool wait_for_results) { auto* mcast_mem = tensorrt_llm::common::findMcastDevMemBuffer(comm_buffer.data_ptr()); TORCH_CHECK(mcast_mem != nullptr, "two_shot_all_reduce: comm_buffer must be obtained from a mcastBuffer instance."); @@ -1120,6 +1164,7 @@ at::Tensor mnnvlTwoShotAllReduce( allreduce_params.dtype = dtype; allreduce_params.output = output.data_ptr(); allreduce_params.input = input.data_ptr(); + allreduce_params.buffer_size = static_cast<uint32_t>(buffer_size); allreduce_params.buffer_flags = buffer_flags.data_ptr(); allreduce_params.wait_for_results = wait_for_results; allreduce_params.stream = at::cuda::getCurrentCUDAStream(output.get_device()); @@ -1137,7 +1182,7 @@ at::Tensor mnnvlTwoShotAllReduce( } std::vector<torch::Tensor> twoShotRMSNorm(torch::Tensor const& comm_buf, torch::Tensor const& gamma, double epsilon, - torch::Tensor const& residual, torch::Tensor& buffer_flags) + torch::Tensor const& residual, torch::Tensor& buffer_flags, int64_t buffer_size) { auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(comm_buf.scalar_type()); auto rmsnorm_params = tensorrt_llm::kernels::mnnvl::RMSNormParams(); @@ -1153,6 +1198,7 @@ std::vector<torch::Tensor> twoShotRMSNorm(torch::Tensor const& comm_buf, torch:: rmsnorm_params.gamma = gamma.data_ptr(); rmsnorm_params.epsilon = epsilon; rmsnorm_params.residual = residual.data_ptr(); + rmsnorm_params.buffer_size = static_cast<uint32_t>(buffer_size); rmsnorm_params.buffer_flags = reinterpret_cast<uint32_t*>(buffer_flags.data_ptr()); rmsnorm_params.batch = normed_output.size(0); rmsnorm_params.hidden_dim = normed_output.size(1); @@ -1168,10 +1214,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( "mnnvl_twoshot_allreduce(Tensor(input!) input, Tensor(comm_buf!) comm_buffer, " - "Tensor(buffer_flags!) buffer_flags, bool wait_for_result) -> Tensor"); + "Tensor(buffer_flags!) buffer_flags, int buffer_size, bool wait_for_result) -> Tensor"); m.def( "mnnvl_twoshot_rmsnorm(Tensor comm_buf, Tensor gamma, " - "float epsilon, Tensor residual, Tensor buffer_flags) -> Tensor[]"); + "float epsilon, Tensor residual, Tensor buffer_flags, int buffer_size) -> Tensor[]"); m.def( "allreduce(" "Tensor input," diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 70976b27e5..55c285be7b 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -21,6 +21,7 @@ #include "tensorrt_llm/kernels/mlaKernels.h" #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" +#include "tensorrt_llm/thop/attentionOp.h" #include "tensorrt_llm/thop/thUtils.h" #include <cstdint> #include <functional> @@ -79,7 +80,8 @@ public: torch::optional<torch::Tensor> mla_context_paged_kv, torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets, torch::optional<torch::Tensor> softmax_stats_tensor, - c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params) const + c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params, + torch::optional<torch::Tensor> attention_sinks) const = 0; }; @@ -133,7 +135,8 @@ public: torch::optional<torch::Tensor> mla_context_paged_kv, torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets, torch::optional<torch::Tensor> softmax_stats_tensor, - c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params) const override + c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params, + torch::optional<torch::Tensor> attention_sinks) const override { auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); T* attention_input = static_cast<T*>(qkv.slice(0, token_offset).data_ptr()); @@ -151,7 +154,10 @@ public: { rotary_inv_freq_ptr = rotary_inv_freq.value().data_ptr<float>(); } - rotary_cos_sin_ptr = static_cast<float2 const*>(rotary_cos_sin.value().data_ptr()); + if (rotary_cos_sin.has_value()) + { + rotary_cos_sin_ptr = static_cast<float2 const*>(rotary_cos_sin.value().data_ptr()); + } } void* workspace_ptr = workspace.data_ptr(); @@ -206,22 +212,30 @@ public: // Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same // unless each layer has different attention window sizes. // the kv_cache capacity. - int const max_attention_window_size - = beam_width == 1 ? attention_window_size : cache_indirection.value().size(2); + int const max_attention_window_size = beam_width == 1 ? attention_window_size + : cache_indirection.has_value() ? cache_indirection.value().size(2) + : attention_window_size; // The cyclic_attention_window_size will determine the cyclic kv cache position of new tokens. // Note that this cyclic_attention_window_size might be smaller than the actual kv cache capactity. int const cyclic_attention_window_size = attention_window_size; bool const can_use_one_more_block = beam_width > 1; - int max_blocks_per_sequence = op.useKVCache() ? kv_cache_block_offsets.value().size(-1) : 0; - int32_t const pool_index - = op.useKVCache() ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 0}).item<int32_t>() : 0; - int32_t const layer_idx_in_cache_pool - = op.useKVCache() ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 1}).item<int32_t>() : 0; - KVBlockArray::DataType* block_offsets = static_cast<KVBlockArray::DataType*>( - op.useKVCache() ? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr); - KVBlockArray::DataType* host_block_offsets = static_cast<KVBlockArray::DataType*>( - op.useKVCache() ? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr); + int max_blocks_per_sequence + = op.useKVCache() && kv_cache_block_offsets.has_value() ? kv_cache_block_offsets.value().size(-1) : 0; + int32_t const pool_index = op.useKVCache() && host_kv_cache_pool_mapping.has_value() + ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 0}).item<int32_t>() + : 0; + int32_t const layer_idx_in_cache_pool = op.useKVCache() && host_kv_cache_pool_mapping.has_value() + ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 1}).item<int32_t>() + : 0; + KVBlockArray::DataType* block_offsets + = static_cast<KVBlockArray::DataType*>(op.useKVCache() && kv_cache_block_offsets.has_value() + ? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() + : nullptr); + KVBlockArray::DataType* host_block_offsets + = static_cast<KVBlockArray::DataType*>(op.useKVCache() && host_kv_cache_block_offsets.has_value() + ? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() + : nullptr); auto const cache_elem_size = (op.mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T)); auto const block_size = op.mTokensPerBlock * op.mNumKVHeads * op.mHeadSize; @@ -229,12 +243,12 @@ public: int32_t const kv_factor = op.isMLAEnabled() ? 1 : 2; auto const intra_pool_offset = layer_idx_in_cache_pool * kv_factor * bytes_per_block; - void* host_primary_pool_pointer = op.useKVCache() + void* host_primary_pool_pointer = op.useKVCache() && host_kv_cache_pool_pointers.has_value() ? reinterpret_cast<void*>( reinterpret_cast<char*>(host_kv_cache_pool_pointers.value().index({pool_index, 0}).item<int64_t>()) + intra_pool_offset) : nullptr; - void* host_secondary_pool_pointer = op.useKVCache() + void* host_secondary_pool_pointer = op.useKVCache() && host_kv_cache_pool_pointers.has_value() ? reinterpret_cast<void*>( reinterpret_cast<char*>(host_kv_cache_pool_pointers.value().index({pool_index, 1}).item<int64_t>()) + intra_pool_offset) @@ -242,19 +256,32 @@ public: float const* kv_scale_orig_quant_ptr = nullptr; float const* kv_scale_quant_orig_ptr = nullptr; - if (op.mKVCacheQuantMode.hasKvCacheQuant()) + if (op.mKVCacheQuantMode.hasKvCacheQuant() && kv_scale_orig_quant.has_value() + && kv_scale_quant_orig.has_value()) { kv_scale_orig_quant_ptr = kv_scale_orig_quant.value().data_ptr<float>(); kv_scale_quant_orig_ptr = kv_scale_quant_orig.value().data_ptr<float>(); } // For FP8 output, out_scale represents the output scale. - float const* out_scale_ptr - = (op.mFP8ContextFMHA && !op.mFuseFp4Quant) ? out_scale.value().data_ptr<float>() : nullptr; + float const* out_scale_ptr = (op.mFP8ContextFMHA && !op.mFuseFp4Quant && out_scale.has_value()) + ? out_scale.value().data_ptr<float>() + : nullptr; // For NVFP4 output, out_scale holds the global scale for scaling factors. - float const* out_sf_scale_ptr = op.mFuseFp4Quant ? out_scale.value().data_ptr<float>() : nullptr; + float const* out_sf_scale_ptr + = op.mFuseFp4Quant && out_scale.has_value() ? out_scale.value().data_ptr<float>() : nullptr; + + // The attention_sinks is a float tensor with shape [num_heads_q] + float const* attention_sinks_ptr = nullptr; + if (attention_sinks.has_value()) + { + TORCH_CHECK( + attention_sinks.value().dtype() == torch::kFloat32, "Expected attention_sinks to have float dtype"); + attention_sinks_ptr = attention_sinks.value().data_ptr<float>(); + } AttentionOp::EnqueueParams<T> common_enqueue_params; common_enqueue_params.attention_input = attention_input; + common_enqueue_params.attention_sinks = attention_sinks_ptr; common_enqueue_params.rotary_inv_freq = rotary_inv_freq_ptr; common_enqueue_params.rotary_cos_sin = rotary_cos_sin_ptr; common_enqueue_params.max_past_kv_length = max_past_kv_length; @@ -317,7 +344,9 @@ public: AttentionOp::EnqueueGenerationParams<T> enqueue_params{common_enqueue_params}; enqueue_params.beam_width = beam_width; enqueue_params.num_requests = num_requests; - enqueue_params.cache_indir = beam_width == 1 ? nullptr : cache_indirection.value().data_ptr<int32_t>(); + enqueue_params.cache_indir = beam_width == 1 + ? nullptr + : (cache_indirection.has_value() ? cache_indirection.value().data_ptr<int32_t>() : nullptr); enqueue_params.semaphores = op.multiBlockSemaphores(); enqueue_params.host_past_key_value_lengths = host_past_key_value_lengths.data_ptr<int32_t>(); enqueue_params.start_token_idx_sf = token_offset; @@ -392,31 +421,31 @@ using RunnerPtr = std::shared_ptr<torch_ext::trtllm::attention::RunnerBase>; using torch_ext::trtllm::attention::Runner; using torch_ext::trtllm::attention::AttentionInputType; -void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, - torch::Tensor& output, torch::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype, - torch::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, +void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output, + std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype, + std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types, - torch::optional<torch::Tensor> kv_cache_block_offsets, torch::optional<torch::Tensor> host_kv_cache_block_offsets, - torch::optional<torch::Tensor> host_kv_cache_pool_pointers, - torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection, - torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig, - torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq, - torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache, - torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq, bool const is_fused_qkv, - bool const update_kv_cache, int64_t const predicted_tokens_per_seq, int64_t const layer_idx, - int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size, + std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets, + std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping, + std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant, + std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale, + std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin, + std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe, + std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks, + bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq, + int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type, - c10::ArrayRef<double> rotary_embedding_scales, c10::ArrayRef<int64_t> rotary_embedding_max_position_info, + std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info, bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable, std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim, std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim, - torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas, + std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas, std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets, std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor, - c10::List<bool> spec_decoding_bool_params, c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params) + std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params) { TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx); // Use these tensors to infer if the attention is using KV cache @@ -544,6 +573,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch: static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq), static_cast<int>(layer_num)}; + op->mFP8ContextMLA = tensorrt_llm::common::getSMVersion() == 120 && op->mKVCacheQuantMode.hasFp8KvCache(); op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim; op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache(); // only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64 @@ -646,7 +676,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch: host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv, - mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params); + mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks); } if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly)) @@ -662,7 +692,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch: host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv, - mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params); + mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks); } TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx); @@ -721,77 +751,5 @@ bool attention_supports_nvfp4_output(int64_t const num_heads, int64_t const num_ TORCH_LIBRARY_FRAGMENT(trtllm, m) { - m.def( - "attention_inplace(" - "Tensor q" - ", Tensor? k" - ", Tensor? v" - ", Tensor(a!) output" - ", Tensor(b!)? output_sf" - ", ScalarType? out_dtype" - ", Tensor? workspace" - ", Tensor sequence_length" - ", Tensor host_past_key_value_lengths" - ", Tensor context_lengths" - ", Tensor host_context_lengths" - ", Tensor host_request_types" - ", Tensor? kv_cache_block_offsets" - ", Tensor? host_kv_cache_block_offsets" - ", Tensor? host_kv_cache_pool_pointers" - ", Tensor? host_kv_cache_pool_mapping" - ", Tensor? cache_indirection" - ", Tensor? kv_scale_orig_quant" - ", Tensor? kv_scale_quant_orig" - ", Tensor? out_scale" - ", Tensor? rotary_inv_freq" - ", Tensor? rotary_cos_sin" - ", Tensor? latent_cache" - ", Tensor? q_pe" - ", Tensor? block_ids_per_seq" - ", bool is_fused_qkv" - ", bool update_kv_cache" - ", int predicted_tokens_per_seq" - ", int layer_idx" - ", int num_heads" - ", int num_kv_heads" - ", int head_size" - ", SymInt? tokens_per_block" - ", SymInt max_num_requests" - ", SymInt max_context_length" - ", SymInt attention_window_size" - ", int sink_token_length" - ", int beam_width" - ", int mask_type" - ", int quant_mode" - ", float q_scaling" - ", int position_embedding_type" - ", int rotary_embedding_dim" - ", float rotary_embedding_base" - ", int rotary_embedding_scale_type" - ", float[] rotary_embedding_scales" - ", int[] rotary_embedding_max_position_info" - ", bool use_paged_context_fmha" - ", int? attention_input_type" - ", bool is_mla_enable" - ", int? q_lora_rank" - ", int? kv_lora_rank" - ", int? qk_nope_head_dim" - ", int? qk_rope_head_dim" - ", int? v_head_dim" - ", Tensor? mrope_rotary_cos_sin" - ", Tensor? mrope_position_deltas" - ", Tensor? mla_context_paged_kv" - ", Tensor? mla_context_kv_cache_block_offsets" - ", int? attention_chunk_size" - ", Tensor? softmax_stats_tensor" - ", bool[] spec_decoding_bool_params" - ", Tensor?[] spec_decoding_tensor_params" - ") -> ()"); - m.def("attention_supports_nvfp4_output", &torch_ext::attention_supports_nvfp4_output); } - -TORCH_LIBRARY_IMPL(trtllm, CUDA, m) -{ - m.impl("attention_inplace", &torch_ext::attention_inplace); -} diff --git a/cpp/tensorrt_llm/thop/attentionOp.h b/cpp/tensorrt_llm/thop/attentionOp.h new file mode 100644 index 0000000000..9b827b177a --- /dev/null +++ b/cpp/tensorrt_llm/thop/attentionOp.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <optional> +#include <torch/extension.h> + +namespace torch_ext +{ + +/** + * @brief Attention operation for TensorRT-LLM + * + * This function performs multi-head attention computation in-place, supporting both + * context and generation phases with various optimization features including: + * - Fused QKV processing + * - KV cache management + * - Multiple position embedding types (RoPE, ALiBi, etc.) + * - Quantization support (FP8, FP4, etc.) + * - Multi-layer attention (MLA) + * - Speculative decoding + */ +void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output, + std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype, + std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, + torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types, + std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets, + std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping, + std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant, + std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale, + std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin, + std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe, + std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks, + bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq, + int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size, + std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length, + int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width, + int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type, + int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type, + std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info, + bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable, + std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim, + std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim, + torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas, + std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets, + std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor, + std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params); + +} // namespace torch_ext diff --git a/cpp/tensorrt_llm/thop/fp4BatchedQuantize.cpp b/cpp/tensorrt_llm/thop/fp4BatchedQuantize.cpp index 7c6d65a5c8..01368ee384 100644 --- a/cpp/tensorrt_llm/thop/fp4BatchedQuantize.cpp +++ b/cpp/tensorrt_llm/thop/fp4BatchedQuantize.cpp @@ -57,16 +57,16 @@ std::tuple<at::Tensor, at::Tensor> fp4_batched_quantize( outputShape[rank - 1] = k / 2; at::Tensor valueE2M1 = at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, self.device(), /* stride */ std::nullopt); - at::Tensor scaleFP8SF = at::detail::empty_cuda({b, tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)}, + at::Tensor scaleFP8SF = at::detail::empty_cuda({b, tensorrt_llm::computeSwizzledLayoutSFSize(m, k / sfVecSize)}, SF_DTYPE, self.device(), /* stride */ std::nullopt); // 2D tensor const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); #define LAUNCH_FP4_QUANTIZE_KERNEL(T) \ - tensorrt_llm::kernels::invokeBatchedFP4Quantization(b, m, k, reinterpret_cast<T*>(self.data_ptr()), \ + tensorrt_llm::kernels::invokeFP4Quantization(b, m, k, reinterpret_cast<T*>(self.data_ptr()), \ globalScale.data_ptr<float>(), reinterpret_cast<int64_t*>(valueE2M1.data_ptr()), \ - reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), sfUseUE8M0, mMultiProcessorCount, \ - at::cuda::getCurrentCUDAStream(self.get_device())); + reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), sfUseUE8M0, tensorrt_llm::QuantizationSFLayout::SWIZZLED, \ + mMultiProcessorCount, at::cuda::getCurrentCUDAStream(self.get_device())); if (self.scalar_type() == at::ScalarType::Half) { diff --git a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp index 10b39bc335..b2dd85b5b2 100644 --- a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp @@ -61,14 +61,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape."); } - if (n_group.has_value()) + if (n_group.has_value() && n_group.value() != 0) { TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3, "Routing kernel with groups implies DeepSeekV3 routing method."); TORCH_CHECK(topk_group.has_value(), "if n_group is given, topk_group must be given"); TORCH_CHECK(num_experts % n_group.value() == 0, "num_experts must be divisible by n_group"); - TORCH_CHECK(top_k <= 8, "Current routing kernel (with groups) only supports top_k<=8."); - TORCH_CHECK(topk_group.value() <= 4, "Current routing kernel only (with groups) supports topk_group<=4."); + TORCH_CHECK(top_k <= 8 && top_k > 0, "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."); + TORCH_CHECK(topk_group.value() <= 4 && topk_group.value() > 0, + "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."); TORCH_CHECK(topk_group.value() <= n_group.value(), "n_group must not be smaller than topk_group."); // This check ensures we have enough experts in the selected groups to handle the top_k routing TORCH_CHECK(top_k < (topk_group.value() * num_experts / n_group.value()), @@ -77,7 +78,8 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Renormalize || static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::RenormalizeNaive) { - TORCH_CHECK(top_k == 8, "Current routing kernel (no groups, renormalize) only supports top_k=8."); + TORCH_CHECK(top_k <= 8 && top_k > 0, + "Current routing kernel (no groups, renormalize) only supports top_k<=8 && top_k>0."); } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) { @@ -110,8 +112,8 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r // * 2 to compensate for the fact that sizeof(hidden_states.dtype) is 1 because we pack 2 e2m1 into 1 byte. args.hidden_size = hidden_states.sizes()[1] * 2; args.top_k = top_k; - args.n_group = n_group.value_or(1); - args.topk_group = topk_group.value_or(top_k); + args.n_group = n_group.value_or(0); + args.topk_group = topk_group.value_or(0); args.local_expert_offset = local_expert_offset; args.local_num_experts = local_num_experts; args.routed_scaling_factor = routed_scaling_factor.value_or(1.0); @@ -143,8 +145,9 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r at::Tensor gemm1_output = at::detail::empty_cuda({max_num_padded_tokens, intermediate_size / 2}, at::ScalarType::Float8_e4m3fn, hidden_states.device(), std::nullopt); - at::Tensor gemm1_output_scale = at::detail::empty_cuda({max_num_padded_tokens, intermediate_size / 16}, - at::ScalarType::Float8_e4m3fn, hidden_states.device(), std::nullopt); + int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens, intermediate_size / 16); + at::Tensor gemm1_output_scale + = at::detail::empty_cuda({sf_size}, at::ScalarType::Float8_e4m3fn, hidden_states.device(), std::nullopt); at::Tensor gemm2_output = at::detail::empty_cuda( {max_num_padded_tokens, args.hidden_size}, at::ScalarType::BFloat16, hidden_states.device(), std::nullopt); @@ -158,12 +161,6 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r at::Tensor num_non_exiting_ctas = at::empty({}, at::TensorOptions().device(routing_logits.device()).dtype(at::ScalarType::Int)); - // FIXME: check shape - auto const hidden_states_scale_linear_size - = tensorrt_llm::computeFP4LinearLayoutSFSize(args.num_tokens, args.hidden_size / 16); - at::Tensor hidden_states_scale_linear = at::detail::empty_cuda( - hidden_states_scale_linear_size, at::ScalarType::Float8_e4m3fn, hidden_states.device(), std::nullopt); - // // TopK routing // @@ -188,7 +185,7 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r TORCH_CHECK(hidden_states_scale.dim() == 1, "hidden_states_scale must be 1D."); TORCH_CHECK(hidden_states_scale.sizes()[0] - == tensorrt_llm::computeFP4LinearLayoutSFSize(args.num_tokens, args.hidden_size / 16), + == tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens, args.hidden_size / 16), "hidden_states_scale has incorrect size"); TORCH_CHECK(gemm1_weights.scalar_type() == FLOAT4_E2M1X2, "gemm1_weights must be byte."); @@ -256,8 +253,6 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r workspace.cta_idx_xy_to_mn_limit = cta_idx_xy_to_mn_limit.data_ptr<int>(); workspace.num_non_exiting_ctas = num_non_exiting_ctas.data_ptr<int>(); - workspace.hidden_states_scale_linear = hidden_states_scale_linear.data_ptr(); - // gemm1 intermediate ws workspace.gemm1_output = gemm1_output.data_ptr(); workspace.gemm1_output_scale = reinterpret_cast<float*>(gemm1_output_scale.data_ptr()); diff --git a/cpp/tensorrt_llm/thop/fp4Op.cpp b/cpp/tensorrt_llm/thop/fp4Op.cpp index 75de02586e..54746be1c7 100644 --- a/cpp/tensorrt_llm/thop/fp4Op.cpp +++ b/cpp/tensorrt_llm/thop/fp4Op.cpp @@ -16,6 +16,7 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/thop/fp8Op.h" #include "tensorrt_llm/thop/thUtils.h" #include <cuda_fp16.h> @@ -104,46 +105,6 @@ float e2M1ToFloat(uint8_t value) return result; } -// Given the rowIdx and colIdx in the unswizzled SFMatrix, compute the 1D offset in the swizzled SFMatrix. -// colIdx and totalCloumn should be in SFMatrix, not activation Matrix, so no sfVecSize needed. -int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, tensorrt_llm::FP4QuantizationSFLayout layout) -{ - constexpr int kColumnGroup0Size = 4; - constexpr int kRowGroup0Size = 32; - constexpr int kRowGroup1Size = kRowGroup0Size * 4; - - // Swizzled layout is used as default layout. - if (layout == tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED) - { - // int paddedRow = PadUpFn(totalRow, 128); - int paddedColumn = PadUpFn(totalColumn, 4); - - int columnIdxInGroup0 = colIdx % kColumnGroup0Size; - int columnGroupIdx = colIdx / kColumnGroup0Size; - constexpr int columnGroupStride = kColumnGroup0Size * kRowGroup1Size; - - int rowIdxInGroup0 = rowIdx % kRowGroup0Size; - int rowIdxInGroup1 = rowIdx % kRowGroup1Size / kRowGroup0Size; - int rowGroupIdx = rowIdx / kRowGroup1Size; - constexpr int rowGroup1Stride = kColumnGroup0Size; - constexpr int rowGroup0Stride = kColumnGroup0Size * rowGroup1Stride; - int rowGroupStride = kRowGroup1Size * paddedColumn; - - return columnIdxInGroup0 + columnGroupIdx * columnGroupStride + rowIdxInGroup0 * rowGroup0Stride - + rowIdxInGroup1 * rowGroup1Stride + rowGroupIdx * rowGroupStride; - } - // Linear layout is only used in E2M1AndUFP8SFScaleToFloatV2. - else if (layout == tensorrt_llm::FP4QuantizationSFLayout::LINEAR) - { - // no padding needed. totalColumn is multiple of kVecSize. - return rowIdx * totalColumn + colIdx; - } - else - { - TLLM_THROW("Other layout not implemented yet."); - } -} - torch::autograd::variable_list FloatToE2M1AndUFP8SFScale( th::Tensor floatTensor, int64_t sfVecSize, int64_t sfType, torch::optional<bool> isSfSwizzledLayout) { @@ -153,7 +114,7 @@ torch::autograd::variable_list FloatToE2M1AndUFP8SFScale( TORCH_CHECK(inputShape[1] % sfVecSize == 0); th::Tensor valueE2M1 = th::zeros({inputShape[0], inputShape[1] / 2}, th::dtype(FLOAT4_E2M1X2).requires_grad(false)); th::Tensor scaleFP8SF - = th::zeros({tensorrt_llm::computeFP4SwizzledLayoutSFSize(inputShape[0], inputShape[1] / sfVecSize)}, + = th::zeros({tensorrt_llm::computeSwizzledLayoutSFSize(inputShape[0], inputShape[1] / sfVecSize)}, th::dtype(SF_DTYPE).requires_grad(false)); th::Tensor repFloat = th::zeros(inputShape, th::dtype(th::kFloat32).requires_grad(false)); @@ -162,9 +123,9 @@ torch::autograd::variable_list FloatToE2M1AndUFP8SFScale( int groupsPerHiddenDim = hiddenDim / sfVecSize; // Note: if isSfSwizzledLayout is provided, use its value; otherwise default to true. - tensorrt_llm::FP4QuantizationSFLayout layout = isSfSwizzledLayout.value_or(true) - ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED - : tensorrt_llm::FP4QuantizationSFLayout::LINEAR; + tensorrt_llm::QuantizationSFLayout layout = isSfSwizzledLayout.value_or(true) + ? tensorrt_llm::QuantizationSFLayout::SWIZZLED + : tensorrt_llm::QuantizationSFLayout::LINEAR; for (size_t vIdx = 0; vIdx < static_cast<size_t>(inputShape[0]); ++vIdx) { @@ -247,7 +208,7 @@ torch::autograd::variable_list HalfToE2M1AndUFP8SFScale( auto rows = has_experts ? inputShape[1] : inputShape[0]; auto cols = has_experts ? inputShape[2] : inputShape[1]; - auto const expert_sf_size = tensorrt_llm::computeFP4SwizzledLayoutSFSize(rows, cols / sfVecSize); + auto const expert_sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(rows, cols / sfVecSize); TORCH_CHECK(cols % sfVecSize == 0); std::array<int64_t, 3> shape{num_experts, rows, cols / 2}; @@ -263,11 +224,11 @@ torch::autograd::variable_list HalfToE2M1AndUFP8SFScale( size_t const expert_sf_offset = expert_sf_size * eIdx; constexpr int FP4_PER_INT64 = 16; constexpr int FP8_PER_INT32 = 4; - tensorrt_llm::kernels::invokeFP4Quantization(rows, cols, + tensorrt_llm::kernels::invokeFP4Quantization(1, rows, cols, reinterpret_cast<half*>(halfTensor.data_ptr()) + expert_elem_offset, globalScale.data_ptr<float>() + eIdx, reinterpret_cast<int64_t*>(valueE2M1.data_ptr()) + expert_elem_offset / FP4_PER_INT64, reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()) + expert_sf_offset / FP8_PER_INT32, sfType == 0, - tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED, mMultiProcessorCount, 0); + tensorrt_llm::QuantizationSFLayout::SWIZZLED, mMultiProcessorCount, 0); } return {valueE2M1, scaleFP8SF}; @@ -276,7 +237,7 @@ torch::autograd::variable_list HalfToE2M1AndUFP8SFScale( // Interleave (and possibly pad) the weights block scaling factor. // blockScale: [num_experts, rows, cols] or [rows, cols] // Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4) -th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale) +th::Tensor BlockScaleInterleave(th::Tensor const& blockScale) { bool is_cuda = blockScale.device().is_cuda(); if (is_cuda) @@ -293,7 +254,7 @@ th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale) auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0]; auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1]; - auto expert_out_size = tensorrt_llm::computeFP4SwizzledLayoutSFSize(rows, cols); + auto expert_out_size = tensorrt_llm::computeSwizzledLayoutSFSize(rows, cols); auto rows_padded = PadUpFn(rows, 128); auto cols_padded = PadUpFn(cols, 4); TORCH_CHECK( @@ -305,7 +266,7 @@ th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale) { const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount(); auto stream = at::cuda::getCurrentCUDAStream(blockScale.get_device()); - tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleave(num_experts, rows, rows_padded, cols, cols_padded, + tensorrt_llm::kernels::invokeBlockScaleInterleave(num_experts, rows, rows_padded, cols, cols_padded, blockScale.data_ptr<uint8_t>(), static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream); } else @@ -325,8 +286,7 @@ th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale) { sf_ori = blockScalePtr[cIdx]; } - int sf_index - = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED); + int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::QuantizationSFLayout::SWIZZLED); interleavedBlockScalePtr[sf_index] = sf_ori; } } @@ -340,7 +300,7 @@ th::Tensor NVFP4BlockScaleInterleave(th::Tensor const& blockScale) // blockScale: [num_experts, rows, cols] or [rows, cols] // Note: rows and cols are the dimensions of the original unswizzled SFMatrix, so reshape input before passing into // this function! Return: The same shape as blockScale -th::Tensor NVFP4BlockScaleInterleaveReverse(th::Tensor const& blockScale) +th::Tensor BlockScaleInterleaveReverse(th::Tensor const& blockScale) { bool is_cuda = blockScale.device().is_cuda(); if (is_cuda) @@ -366,7 +326,7 @@ th::Tensor NVFP4BlockScaleInterleaveReverse(th::Tensor const& blockScale) { const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount(); auto stream = at::cuda::getCurrentCUDAStream(blockScale.get_device()); - tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleaveReverse(num_experts, rows, cols, + tensorrt_llm::kernels::invokeBlockScaleInterleaveReverse(num_experts, rows, cols, blockScale.data_ptr<uint8_t>(), static_cast<uint8_t*>(reversedBlockScale.data_ptr()), smCount, stream); } else @@ -379,8 +339,7 @@ th::Tensor NVFP4BlockScaleInterleaveReverse(th::Tensor const& blockScale) { for (int cIdx = 0; cIdx < cols; ++cIdx) { - int sf_index - = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED); + int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, tensorrt_llm::QuantizationSFLayout::SWIZZLED); identity[eIdx * expert_out_size + sf_index] = std::array<int, 3>{eIdx, rIdx, cIdx}; } } @@ -424,7 +383,7 @@ th::Tensor E2M1AndUFP8SFScaleToFloat(th::Tensor valueE2M1, th::Tensor scaleFP8SF uint8_t* packedFp4Ptr = valueE2M1.data_ptr<uint8_t>() + vIdx * packedFp4HiddenDim + group * sfVecSize / 2; uint8_t* scaleFP8SFPtr = scaleFP8SF.data_ptr<uint8_t>(); uint8_t fp8Scale = scaleFP8SFPtr[computeSFIndex( - vIdx, group, packedShape[0], groupsPerHiddenDim, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED)]; + vIdx, group, packedShape[0], groupsPerHiddenDim, tensorrt_llm::QuantizationSFLayout::SWIZZLED)]; int scale = fp8Scale; if (sfType == 0) { @@ -453,8 +412,8 @@ th::Tensor E2M1AndUFP8SFScaleToFloat(th::Tensor valueE2M1, th::Tensor scaleFP8SF } // Used by the (fp16 -> int4) quant layer + int4 gemm network. -th::Tensor E2M1AndUFP8SFScaleToFloatV2(th::Tensor valueE2M1, th::Tensor scaleFP8SF, th::Tensor globalScale, - int64_t sfVecSize, int64_t sfType, bool isSfSwizzledLayout = true) +th::Tensor E2M1AndUFP8SFScaleToFloatV2(th::Tensor valueE2M1, th::Tensor scaleFP8SF, + std::optional<th::Tensor> globalScale, int64_t sfVecSize, int64_t sfType, bool isSfSwizzledLayout = true) { CHECK_CPU_INPUT(valueE2M1, FLOAT4_E2M1X2); CHECK_CPU_INPUT(scaleFP8SF, SF_DTYPE); @@ -465,15 +424,20 @@ th::Tensor E2M1AndUFP8SFScaleToFloatV2(th::Tensor valueE2M1, th::Tensor scaleFP8 th::Tensor floatTensor = th::zeros({packedShape[0], packedShape[1] * 2}, th::dtype(th::kFloat32).requires_grad(false)); - CHECK_CPU_INPUT(globalScale, th::kFloat32); - float globalScaleVal = globalScale.data_ptr<float>()[0]; + float globalScaleVal{1.0f}; + if (sfType == 1) + { + TORCH_CHECK(globalScale.has_value(), "globalScale is required when sfType is 1."); + CHECK_CPU_INPUT(globalScale.value(), th::kFloat32); + globalScaleVal = globalScale->data_ptr<float>()[0]; + } int hiddenDim = packedShape[1] * 2; int packedFp4HiddenDim = hiddenDim / 2; int groupsPerHiddenDim = hiddenDim / sfVecSize; - tensorrt_llm::FP4QuantizationSFLayout layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED - : tensorrt_llm::FP4QuantizationSFLayout::LINEAR; + tensorrt_llm::QuantizationSFLayout layout = isSfSwizzledLayout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED + : tensorrt_llm::QuantizationSFLayout::LINEAR; for (size_t vIdx = 0; vIdx < static_cast<size_t>(packedShape[0]); ++vIdx) { @@ -526,18 +490,18 @@ static auto e2m1_and_ufp8sf_scale_to_float_v2 = torch::RegisterOperators( TORCH_LIBRARY_FRAGMENT(trtllm, m) { - m.def("nvfp4_block_scale_interleave(Tensor input) -> Tensor"); - m.def("nvfp4_block_scale_interleave_reverse(Tensor input) -> Tensor"); + m.def("block_scale_interleave(Tensor input) -> Tensor"); + m.def("block_scale_interleave_reverse(Tensor input) -> Tensor"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { - m.impl("nvfp4_block_scale_interleave", &torch_ext::NVFP4BlockScaleInterleave); - m.impl("nvfp4_block_scale_interleave_reverse", &torch_ext::NVFP4BlockScaleInterleaveReverse); + m.impl("block_scale_interleave", &torch_ext::BlockScaleInterleave); + m.impl("block_scale_interleave_reverse", &torch_ext::BlockScaleInterleaveReverse); } TORCH_LIBRARY_IMPL(trtllm, CPU, m) { - m.impl("nvfp4_block_scale_interleave", &torch_ext::NVFP4BlockScaleInterleave); - m.impl("nvfp4_block_scale_interleave_reverse", &torch_ext::NVFP4BlockScaleInterleaveReverse); + m.impl("block_scale_interleave", &torch_ext::BlockScaleInterleave); + m.impl("block_scale_interleave_reverse", &torch_ext::BlockScaleInterleaveReverse); } diff --git a/cpp/tensorrt_llm/thop/fp4Quantize.cpp b/cpp/tensorrt_llm/thop/fp4Quantize.cpp index 030fe3e06d..956b752324 100644 --- a/cpp/tensorrt_llm/thop/fp4Quantize.cpp +++ b/cpp/tensorrt_llm/thop/fp4Quantize.cpp @@ -24,26 +24,41 @@ #include <cuda_fp16.h> #include <cstdint> +#include <optional> namespace torch_ext { // self: [M, K], fp16/bf16/fp8_quantized -// globalScale: [1] float, = (448 * 6) / self.abs().max() +// globalScale: [1] float, = (448 * 6) / self.abs().max(). Not used when sfUseUE8M0 is true. // nvfp4: sfVecSize = 16, sfUseUE8M0 = false -// mxfp4: sfVecSize = 32 (not supported yet), sfUseUE8M0 = true +// mxfp4: sfVecSize = 32, sfUseUE8M0 = true // alignment: sfVecSize // isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in linear layout. -// See FP4QuantizationSFLayout enum for more details about the two layouts. +// See QuantizationSFLayout enum for more details about the two layouts. // returns self_fp4, self_block_scale_factors // self_fp4: [M, K / 2], FLOAT4_E2M1X2 // self_block_scale_factors: ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0) -std::tuple<at::Tensor, at::Tensor> fp4_quantize( - at::Tensor const& self, at::Tensor const& globalScale, int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout) +std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self, std::optional<at::Tensor> const& globalScale, + int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout) { CHECK_TH_CUDA(self); CHECK_CONTIGUOUS(self); - CHECK_INPUT(globalScale, torch::kFloat32); - TORCH_CHECK(sfVecSize == 16, "sfVecSize can only be 16"); + if (sfUseUE8M0) + { + TORCH_CHECK(sfVecSize == 32, "sfVecSize can only be 32, when sfUseUE8M0 is true"); + } + else + { + TORCH_CHECK(globalScale.has_value(), "globalScale is required when sfUseUE8M0 is false"); + CHECK_INPUT(globalScale.value(), torch::kFloat32); + TORCH_CHECK(sfVecSize == 16, "sfVecSize can only be 16, when sfUseUE8M0 is false"); + } + + float* globalScalePtr{nullptr}; + if (globalScale.has_value()) + { + globalScalePtr = globalScale->data_ptr<float>(); + } auto const& inputShape = self.sizes(); auto const& rank = inputShape.size(); @@ -62,46 +77,76 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize( at::Tensor valueE2M1 = at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, self.device(), /* stride */ std::nullopt); - int64_t SFSize = isSfSwizzledLayout ? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize) - : tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize); + int64_t SFSize = isSfSwizzledLayout ? tensorrt_llm::computeSwizzledLayoutSFSize(m, k / sfVecSize) + : tensorrt_llm::computeLinearLayoutSFSize(m, k / sfVecSize); at::Tensor scaleFP8SF = at::detail::empty_cuda({SFSize}, SF_DTYPE, self.device(), /* stride */ std::nullopt); // 1D tensor const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); - auto const layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED - : tensorrt_llm::FP4QuantizationSFLayout::LINEAR; + auto const layout = isSfSwizzledLayout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED + : tensorrt_llm::QuantizationSFLayout::LINEAR; -#define LAUNCH_FP4_QUANTIZE_KERNEL(T) \ - tensorrt_llm::kernels::invokeFP4Quantization(m, k, reinterpret_cast<T*>(self.data_ptr()), \ - globalScale.data_ptr<float>(), reinterpret_cast<int64_t*>(valueE2M1.data_ptr()), \ +#define LAUNCH_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \ + tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>(1, m, k, reinterpret_cast<T*>(self.data_ptr()), \ + globalScalePtr, reinterpret_cast<int64_t*>(valueE2M1.data_ptr()), \ reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), sfUseUE8M0, layout, mMultiProcessorCount, \ at::cuda::getCurrentCUDAStream(self.get_device())); - if (self.scalar_type() == at::ScalarType::Half) - { - LAUNCH_FP4_QUANTIZE_KERNEL(half) - } - else if (self.scalar_type() == at::ScalarType::BFloat16) + if (sfUseUE8M0) { + if (self.scalar_type() == at::ScalarType::Half) + { + LAUNCH_FP4_QUANTIZE_KERNEL(half, 32) + } + else if (self.scalar_type() == at::ScalarType::BFloat16) + { #ifdef ENABLE_BF16 - LAUNCH_FP4_QUANTIZE_KERNEL(__nv_bfloat16) + LAUNCH_FP4_QUANTIZE_KERNEL(__nv_bfloat16, 32) #else - C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled to quantize an bf16 tensor to fp4."); + C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled to quantize an bf16 tensor to fp4."); #endif - } - else if (self.scalar_type() == at::ScalarType::Float8_e4m3fn) - { + } + else if (self.scalar_type() == at::ScalarType::Float8_e4m3fn) + { #ifdef ENABLE_FP8 - LAUNCH_FP4_QUANTIZE_KERNEL(__nv_fp8_e4m3) + LAUNCH_FP4_QUANTIZE_KERNEL(__nv_fp8_e4m3, 32) #else - C10_THROW_ERROR(NotImplementedError, "FP8 must be enabled to quantize an fp8 tensor to fp4."); + C10_THROW_ERROR(NotImplementedError, "FP8 must be enabled to quantize an fp8 tensor to fp4."); #endif + } + else + { + C10_THROW_ERROR(NotImplementedError, "fp4_quantize only supports input tensor with dtypes fp16/bf16/e4m3."); + } } else { - C10_THROW_ERROR(NotImplementedError, "fp4_quantize only supports input tensor with dtypes fp16/bf16/e4m3."); + if (self.scalar_type() == at::ScalarType::Half) + { + LAUNCH_FP4_QUANTIZE_KERNEL(half, 16) + } + else if (self.scalar_type() == at::ScalarType::BFloat16) + { +#ifdef ENABLE_BF16 + LAUNCH_FP4_QUANTIZE_KERNEL(__nv_bfloat16, 16) +#else + C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled to quantize an bf16 tensor to fp4."); +#endif + } + else if (self.scalar_type() == at::ScalarType::Float8_e4m3fn) + { +#ifdef ENABLE_FP8 + LAUNCH_FP4_QUANTIZE_KERNEL(__nv_fp8_e4m3, 16) +#else + C10_THROW_ERROR(NotImplementedError, "FP8 must be enabled to quantize an fp8 tensor to fp4."); +#endif + } + else + { + C10_THROW_ERROR(NotImplementedError, "fp4_quantize only supports input tensor with dtypes fp16/bf16/e4m3."); + } } #undef LAUNCH_FP4_QUANTIZE_KERNEL @@ -113,11 +158,12 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize( TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( - "fp4_quantize(Tensor input, Tensor globalScale, int sfVecSize, bool sfUseUE8M0=False, bool swizzedLayout=True) " + "fp4_quantize(Tensor input, Tensor? globalScale, int sfVecSize, bool sfUseUE8M0=False, bool " + "isSfSwizzledLayout=True) " "-> (Tensor, Tensor)"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { - m.impl("fp4_quantize", &torch_ext::fp4_quantize); + m.impl("fp4_quantize", TORCH_FN(torch_ext::fp4_quantize)); } diff --git a/cpp/tensorrt_llm/thop/fp4Quantize.h b/cpp/tensorrt_llm/thop/fp4Quantize.h index ea6dc1f59f..e460cb9c95 100644 --- a/cpp/tensorrt_llm/thop/fp4Quantize.h +++ b/cpp/tensorrt_llm/thop/fp4Quantize.h @@ -20,9 +20,10 @@ #include <ATen/cuda/EmptyTensor.h> #include <cstdint> +#include <optional> namespace torch_ext { -std::tuple<torch::Tensor, torch::Tensor> fp4_quantize(torch::Tensor const& self, torch::Tensor const& globalScale, +std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self, std::optional<at::Tensor> const& globalScale, int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout); } diff --git a/cpp/tensorrt_llm/thop/fp8BatchedGemmTrtllmGen.cpp b/cpp/tensorrt_llm/thop/fp8BatchedGemmTrtllmGen.cpp index 3631b8177b..be1970e480 100644 --- a/cpp/tensorrt_llm/thop/fp8BatchedGemmTrtllmGen.cpp +++ b/cpp/tensorrt_llm/thop/fp8BatchedGemmTrtllmGen.cpp @@ -40,7 +40,6 @@ void runBatchedGemm(at::Tensor& out, at::Tensor& outSfC, at::Tensor const& mat1, std::vector<int32_t> const& batchedTokens, bool useDeepSeekFp8, bool lowLatencyKernel, tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner& runner, int32_t const configIndex) { - // numTokens and maxNumCtasInBatchDim are not used for static batching int32_t const numTokens = 0; int32_t const maxNumCtasInBatchDim = 0; @@ -115,7 +114,8 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1 TORCH_CHECK(dDqSfsB.value().scalar_type() == at::ScalarType::Float, "Scale dtype must be FP32."); TORCH_CHECK(dDqSfsA.value().dim() == 2, "batching M: dDqSfsA must be a 2D matrix"); TORCH_CHECK(dDqSfsA.value().sizes()[0] == k / dsFp8QuantBlockSize, - "batching M: dDqSfsA must have size B x K/dsFp8QuantBlockSize x divUp(m, dsFp8QuantBlockSize) * 128 * b"); + "batching M: dDqSfsA must have size B x K/dsFp8QuantBlockSize x divUp(m, dsFp8QuantBlockSize) * tileSize * " + "b"); TORCH_CHECK( dDqSfsA.value().sizes()[1] == static_cast<int64_t>(tensorrt_llm::common::divUp(m, tileSize) * tileSize * b), "batching M: dDqSfsA must have size B x K/dsFp8QuantBlockSize x divUp(m, tileSize) * tileSize * b"); @@ -207,8 +207,9 @@ public: default: C10_THROW_ERROR(NotImplementedError, "outDtype must be one of fp16/bf16/e4m3."); } - RunnerOptionsType const options = {.eltType = mEltType, - .outputType = outDtype, + RunnerOptionsType const options = {.dtypeA = mEltType, + .dtypeB = mEltType, + .dtypeC = outDtype, .deepSeekFp8 = mUseDeepSeekFp8, .fusedAct = false, .routeAct = false, diff --git a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp index 26fb868ed3..f48a40620f 100644 --- a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp @@ -30,13 +30,13 @@ namespace btg = batchedGemm::trtllm::gen; using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType; using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner; -at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor const& routing_bias, +at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, std::optional<at::Tensor> const& routing_bias, at::Tensor const& hidden_states, at::Tensor const& hidden_states_scale, at::Tensor const& gemm1_weights, at::Tensor const& gemm1_weights_scale, at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale, - int64_t const num_experts, int64_t const top_k, int64_t const n_group, int64_t const topk_group, - int64_t const intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts, - double const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, - MoeRunnerType& moe_runner, int64_t moeConfigIndex) + int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group, + std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset, + int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, + int64_t const routing_method_type, MoeRunnerType& moe_runner, int64_t moeConfigIndex) { auto const sm = tensorrt_llm::common::getSMVersion(); TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE"); @@ -45,26 +45,38 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0], "routing_logits and hidden_states must have the same number of tokens."); TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits dim1 must match num_experts."); - TORCH_CHECK( - routing_bias.scalar_type() == at::ScalarType::BFloat16 || routing_bias.scalar_type() == at::ScalarType::Float, - "routing_bias must be bfloat16 or float."); - TORCH_CHECK(routing_bias.dim() == 1, "routing_bias must be 1D."); - TORCH_CHECK(routing_bias.sizes()[0] == num_experts, "routing_bias has incorrect shape."); - if (n_group <= 0 || topk_group <= 0) + if (routing_bias.has_value()) { - TORCH_CHECK(top_k == 1, "Current routing kernel (no groups) only supports top_k=1."); + TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16, "routing_bias must be bfloat16."); + TORCH_CHECK(routing_bias.value().dim() == 1, "routing_bias must be 1D."); + TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape."); } - else + + if (n_group.has_value() && n_group.value() != 0) { - TORCH_CHECK(top_k <= 8, "Current routing kernel (with groups) only supports top_k<=8."); - TORCH_CHECK(topk_group <= 4, "Current routing kernel (with groups) only supports topk_group<=4."); - TORCH_CHECK(topk_group <= n_group, "n_group must not be smaller than topk_group."); - TORCH_CHECK(num_experts % n_group == 0, "num_experts must be divisible by n_group"); + TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3, + "Routing kernel with groups implies DeepSeekV3 routing method."); + TORCH_CHECK(topk_group.has_value(), "if n_group is given, topk_group must be given"); + TORCH_CHECK(num_experts % n_group.value() == 0, "num_experts must be divisible by n_group"); + TORCH_CHECK(top_k <= 8 && top_k > 0, "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."); + TORCH_CHECK(topk_group.value() <= 4 && topk_group.value() > 0, + "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."); + TORCH_CHECK(topk_group.value() <= n_group.value(), "n_group must not be smaller than topk_group."); // This check ensures we have enough experts in the selected groups to handle the top_k routing - TORCH_CHECK(top_k < (topk_group * num_experts / n_group), + TORCH_CHECK(top_k < (topk_group.value() * num_experts / n_group.value()), "top_k must be less than total number of experts in selected groups"); } + else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Renormalize + || static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::RenormalizeNaive) + { + TORCH_CHECK(false, "Don't support this routing method type Renormalize(Naive)."); + } + else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) + { + TORCH_CHECK(top_k == 1, "Current routing kernel (no groups, Llama4) only supports top_k=1."); + } + TORCH_CHECK(num_experts % 4 == 0, "Routing kernel expects that num_experts must be divisible by 4"); TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k"); @@ -74,9 +86,12 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor // setup args // note: the assumption is that output data type is always Bfloat16 (the default) args.mDtypeElt = btg::Dtype::E4m3; - args.mDtypeExpW = routing_bias.scalar_type() == at::ScalarType::BFloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + auto const routing_bias_dtype + = routing_bias.has_value() ? routing_bias.value().scalar_type() : at::ScalarType::BFloat16; + args.mDtypeExpW = routing_bias_dtype == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16; + args.routing_logits = routing_logits.data_ptr<float>(); - args.routing_bias = routing_bias.data_ptr(); + args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr; args.hidden_states = hidden_states.data_ptr(); args.hidden_states_scale = hidden_states_scale.data_ptr<float>(); args.gemm1_weights = gemm1_weights.data_ptr(); @@ -87,11 +102,11 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor args.num_experts = num_experts; args.hidden_size = hidden_states.sizes()[1]; args.top_k = top_k; - args.n_group = n_group; - args.topk_group = topk_group; + args.n_group = n_group.value_or(0); + args.topk_group = topk_group.value_or(0); args.local_expert_offset = local_expert_offset; args.local_num_experts = local_num_experts; - args.routed_scaling_factor = routed_scaling_factor; + args.routed_scaling_factor = routed_scaling_factor.value_or(1.0); args.intermediate_size = intermediate_size; args.mUseDeepSeekFp8 = true; @@ -108,7 +123,7 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor at::Tensor permuted_idx_to_token_idx = at::detail::empty_cuda({max_num_padded_tokens}, at::ScalarType::Int, routing_logits.device(), std::nullopt); at::Tensor expert_weights = at::detail::empty_cuda( - {args.num_tokens, args.top_k}, routing_bias.scalar_type(), routing_logits.device(), std::nullopt); + {args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits.device(), std::nullopt); at::Tensor expert_indexes = at::detail::empty_cuda( {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt); int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); @@ -139,7 +154,7 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream(routing_logits.get_device()); - routing_runner.run(routing_logits.data_ptr<float>(), routing_bias.data_ptr(), args.num_tokens, args.num_experts, + routing_runner.run(routing_logits.data_ptr<float>(), args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(), expanded_idx_to_permuted_idx.data_ptr<int>(), @@ -241,12 +256,13 @@ public: return mRunner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens); } - [[nodiscard]] at::Tensor run(at::Tensor const& routing_logits, at::Tensor const& routing_bias, + [[nodiscard]] at::Tensor run(at::Tensor const& routing_logits, std::optional<at::Tensor> const& routing_bias, at::Tensor const& hidden_states, at::Tensor const& hidden_states_scale, at::Tensor const& gemm1_weights, at::Tensor const& gemm1_weights_scale, at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale, - int64_t num_experts, int64_t top_k, int64_t n_group, int64_t topk_group, int64_t intermediate_size, - int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor, - int64_t routing_method_type, int64_t moeConfigIndex) + int64_t num_experts, int64_t top_k, std::optional<int64_t> const n_group, + std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset, + int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t routing_method_type, + int64_t moeConfigIndex) { // Autotuner has requested a default or 'fallback' config index diff --git a/cpp/tensorrt_llm/thop/fp8Op.cpp b/cpp/tensorrt_llm/thop/fp8Op.cpp index afd6e388d8..21f56757c6 100644 --- a/cpp/tensorrt_llm/thop/fp8Op.cpp +++ b/cpp/tensorrt_llm/thop/fp8Op.cpp @@ -15,6 +15,7 @@ */ #include "tensorrt_llm/thop/fp8Op.h" +#include "cutlass/numeric_types.h" #include "tensorrt_llm/common/cudaBf16Wrapper.h" #include "tensorrt_llm/common/cudaFp8Utils.h" #include "tensorrt_llm/thop/thUtils.h" @@ -206,6 +207,122 @@ Tensor e4m3_dequantize_helper(Tensor input, Tensor scales, QuantizeMode quantize return dequantized_input; } +inline uint8_t float_to_ue8m0(float value) +{ + if (value == 0.0f) + { + return 0x00; + } + constexpr uint32_t FP32_MANTISSA_BITS = 23; + uint32_t val_u32 = *reinterpret_cast<uint32_t*>(&value); + uint8_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) + { + ++exponent; + } + return exponent; +} + +// Used in tests to quantize mxe4m3 tensors on host. +std::tuple<Tensor, Tensor> quantize_mxe4m3_host(Tensor x_fp32, bool is_sf_swizzled_layout = true) +{ + int32_t const sf_vec_size = 32; + CHECK_CPU_INPUT(x_fp32, torch::kFloat32); + auto data_shape = x_fp32.sizes(); + TORCH_CHECK(data_shape.size() == 2, "x_fp32 should be 2D tensor."); + int num_tokens = data_shape[0]; + int hidden_dim = data_shape[1]; + int groups_per_hidden_dim = hidden_dim / sf_vec_size; + + Tensor fp8_tensor = at::detail::empty_cpu( + {num_tokens, hidden_dim}, at::ScalarType::Byte, /* pinned */ true, at::MemoryFormat::Contiguous); + int64_t sf_size = is_sf_swizzled_layout + ? tensorrt_llm::computeSwizzledLayoutSFSize(num_tokens, hidden_dim / sf_vec_size) + : tensorrt_llm::computeLinearLayoutSFSize(num_tokens, hidden_dim / sf_vec_size); + Tensor scale_tensor = at::detail::empty_cpu({sf_size}, SF_DTYPE, /* pinned */ true, at::MemoryFormat::Contiguous); + + tensorrt_llm::QuantizationSFLayout layout = is_sf_swizzled_layout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED + : tensorrt_llm::QuantizationSFLayout::LINEAR; + + for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) + { + for (int group = 0; group < groups_per_hidden_dim; ++group) + { + float* fp32_ptr = x_fp32.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size; + uint8_t* fp8_ptr = fp8_tensor.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size; + + uint8_t* scale_ue8m08sf_ptr = scale_tensor.data_ptr<uint8_t>(); + + float local_amax = 0.0f; + for (int ki = 0; ki < sf_vec_size; ++ki) + { + local_amax = std::max(std::abs(fp32_ptr[ki]), local_amax); + } + + local_amax *= (1.f / 448.0f); + + uint8_t scale_ue8m0 = float_to_ue8m0(local_amax); + auto const inv_scale = (scale_ue8m0 == 0) ? 1 : exp2f(127 - static_cast<float>(scale_ue8m0)); + + scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0], groups_per_hidden_dim, layout)] = scale_ue8m0; + + for (int ki = 0; ki < sf_vec_size; ++ki) + { + float const scaled_fp32_value = fp32_ptr[ki] * inv_scale; + auto fp8_value = cutlass::float_e4m3_t{scaled_fp32_value}; + fp8_ptr[ki] = *reinterpret_cast<uint8_t*>(&fp8_value); + } + } + } + return std::make_tuple(fp8_tensor, scale_tensor); +} + +// Used in tests to dequantize mxe4m3 tensors on host. +Tensor dequantize_mxe4m3_host(Tensor value_e4m3, Tensor scale_ue8m08sf, bool is_sf_swizzled_layout = true) +{ + int32_t const sf_vec_size = 32; + CHECK_CPU_INPUT(value_e4m3, at::ScalarType::Byte); + CHECK_CPU_INPUT(scale_ue8m08sf, SF_DTYPE); + auto data_shape = value_e4m3.sizes(); + auto scale_shape = scale_ue8m08sf.sizes(); + TORCH_CHECK(data_shape.size() == 2, "value_e4m3 should be 2D tensor."); + TORCH_CHECK(scale_shape.size() == 1, "scale_ue8m08sf should be 1D tensor."); + Tensor float_tensor = at::detail::empty_cpu( + {data_shape[0], data_shape[1]}, at::ScalarType::Float, /* pinned */ true, at::MemoryFormat::Contiguous); + + int hidden_dim = data_shape[1]; + int groups_per_hidden_dim = hidden_dim / sf_vec_size; + + tensorrt_llm::QuantizationSFLayout layout = is_sf_swizzled_layout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED + : tensorrt_llm::QuantizationSFLayout::LINEAR; + for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) + { + for (int group = 0; group < groups_per_hidden_dim; ++group) + { + float* float_ptr = float_tensor.data_ptr<float>() + ti * hidden_dim + group * sf_vec_size; + uint8_t* fp8_ptr = value_e4m3.data_ptr<uint8_t>() + ti * hidden_dim + group * sf_vec_size; + uint8_t* scale_ue8m08sf_ptr = scale_ue8m08sf.data_ptr<uint8_t>(); + uint8_t fp8_scale + = scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0], groups_per_hidden_dim, layout)]; + + float scale_float; + uint32_t scale_float_u32 = uint32_t(fp8_scale) << 23; + memcpy(&scale_float, &scale_float_u32, sizeof(scale_float)); + + for (int ki = 0; ki < sf_vec_size; ++ki) + { + uint8_t fp8_u8_repr = fp8_ptr[ki]; + auto fp32 = static_cast<float>(*reinterpret_cast<cutlass::float_e4m3_t*>(&fp8_u8_repr)); + float value = fp32 * scale_float; + float_ptr[ki] = value; + } + } + } + return float_tensor; +} + std::tuple<Tensor, Tensor> symmetric_quantize_weight(Tensor weight) { return e4m3_quantize_helper(weight, at::nullopt, QuantizeMode::PER_CHANNEL); @@ -279,3 +396,9 @@ TORCH_LIBRARY_IMPL(tensorrt_llm, CUDA, m) m.impl("dequantize_e4m3_activation", &torch_ext::symmetric_dequantize_activation); m.impl("dequantize_e4m3_per_tensor", &torch_ext::symmetric_dequantize_per_tensor); } + +static auto dequantize_mxe4m3_host + = torch::RegisterOperators("tensorrt_llm::dequantize_mxe4m3_host", &torch_ext::dequantize_mxe4m3_host); + +static auto quantize_mxe4m3_host + = torch::RegisterOperators("tensorrt_llm::quantize_mxe4m3_host", &torch_ext::quantize_mxe4m3_host); diff --git a/cpp/tensorrt_llm/thop/fp8Op.h b/cpp/tensorrt_llm/thop/fp8Op.h index f12f166c4e..1b08935d1d 100644 --- a/cpp/tensorrt_llm/thop/fp8Op.h +++ b/cpp/tensorrt_llm/thop/fp8Op.h @@ -28,6 +28,47 @@ namespace torch_ext { +// Given the rowIdx and colIdx in the unswizzled SFMatrix, compute the 1D offset in the swizzled SFMatrix. +// colIdx and totalCloumn should be in SFMatrix, not activation Matrix, so no sfVecSize needed. +inline int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, + tensorrt_llm::QuantizationSFLayout layout, bool useUE8M0 = false) +{ + constexpr int kColumnGroup0Size = 4; + constexpr int kRowGroup0Size = 32; + constexpr int kRowGroup1Size = kRowGroup0Size * 4; + + // Swizzled layout is used as default layout. + if (layout == tensorrt_llm::QuantizationSFLayout::SWIZZLED) + { + // int paddedRow = PadUpFn(totalRow, 128); + int paddedColumn = PadUpFn(totalColumn, 4); + + int columnIdxInGroup0 = colIdx % kColumnGroup0Size; + int columnGroupIdx = colIdx / kColumnGroup0Size; + constexpr int columnGroupStride = kColumnGroup0Size * kRowGroup1Size; + + int rowIdxInGroup0 = rowIdx % kRowGroup0Size; + int rowIdxInGroup1 = (rowIdx % kRowGroup1Size) / kRowGroup0Size; + int rowGroupIdx = rowIdx / kRowGroup1Size; + constexpr int rowGroup1Stride = kColumnGroup0Size; + constexpr int rowGroup0Stride = kColumnGroup0Size * rowGroup1Stride; + int rowGroupStride = kRowGroup1Size * paddedColumn; + + return columnIdxInGroup0 + columnGroupIdx * columnGroupStride + rowIdxInGroup0 * rowGroup0Stride + + rowIdxInGroup1 * rowGroup1Stride + rowGroupIdx * rowGroupStride; + } + // Linear layout is only used in E2M1AndUFP8SFScaleToFloatV2. + else if (layout == tensorrt_llm::QuantizationSFLayout::LINEAR) + { + // no padding needed. totalColumn is multiple of kVecSize. + return rowIdx * totalColumn + colIdx; + } + else + { + TLLM_THROW("Other layout not implemented yet."); + } +} + std::tuple<torch::Tensor, torch::Tensor> symmetric_quantize_weight(torch::Tensor weight); std::tuple<torch::Tensor, torch::Tensor> symmetric_quantize_activation(torch::Tensor activation); std::tuple<torch::Tensor, torch::Tensor> symmetric_quantize_per_tensor(torch::Tensor input); diff --git a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp index 395c4320b2..b76701f788 100644 --- a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp @@ -25,12 +25,14 @@ namespace torch_ext namespace btg = batchedGemm::trtllm::gen; using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType; -torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logits, torch::Tensor const& routing_bias, - torch::Tensor const& hidden_states, torch::Tensor const& gemm1_weights, torch::Tensor const& output1_scales_scalar, +torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logits, + torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states, + torch::Tensor const& gemm1_weights, torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& gemm2_weights, - torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k, int64_t const n_group, - int64_t const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset, - int64_t const local_num_experts, double const routed_scaling_factor, bool const use_routing_scales_on_input, + torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k, + std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size, + int64_t const local_expert_offset, int64_t const local_num_experts, + std::optional<double> const routed_scaling_factor, bool const use_routing_scales_on_input, int64_t const tile_tokens_dim, int64_t const routing_method_type) { @@ -46,24 +48,38 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit } TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape."); - TORCH_CHECK(routing_bias.scalar_type() == at::ScalarType::BFloat16, "routing_bias must be bfloat16."); - TORCH_CHECK(routing_bias.dim() == 1, "routing_bias must be 1D."); - TORCH_CHECK(routing_bias.sizes()[0] == num_experts, "routing_bias has incorrect shape."); - if (n_group <= 0 || topk_group <= 0) + if (routing_bias.has_value()) { - TORCH_CHECK(top_k == 1, "Current routing kernel (no groups) only supports top_k=1."); + TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16, "routing_bias must be bfloat16."); + TORCH_CHECK(routing_bias.value().dim() == 1, "routing_bias must be 1D."); + TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape."); } - else + + if (n_group.has_value() && n_group.value() != 0) { - TORCH_CHECK(top_k <= 8, "Current routing kernel (with groups) only supports top_k<=8."); - TORCH_CHECK(topk_group <= 4, "Current routing kernel (with groups) only supports topk_group<=4."); - TORCH_CHECK(topk_group <= n_group, "n_group must not be smaller than topk_group."); - TORCH_CHECK(num_experts % n_group == 0, "num_experts must be divisible by n_group"); + TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3, + "Routing kernel with groups implies DeepSeekV3 routing method."); + TORCH_CHECK(topk_group.has_value(), "if n_group is given, topk_group must be given"); + TORCH_CHECK(num_experts % n_group.value() == 0, "num_experts must be divisible by n_group"); + TORCH_CHECK(top_k <= 8 && top_k > 0, "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."); + TORCH_CHECK(topk_group.value() <= 4 && topk_group.value() > 0, + "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."); + TORCH_CHECK(topk_group.value() <= n_group.value(), "n_group must not be smaller than topk_group."); // This check ensures we have enough experts in the selected groups to handle the top_k routing - TORCH_CHECK(top_k < (topk_group * num_experts / n_group), + TORCH_CHECK(top_k < (topk_group.value() * num_experts / n_group.value()), "top_k must be less than total number of experts in selected groups"); } + else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Renormalize + || static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::RenormalizeNaive) + { + TORCH_CHECK(false, "Don't support routing method type Renormalize(Naive)."); + } + else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) + { + TORCH_CHECK(top_k == 1, "Current routing kernel (no groups, Llama4) only supports top_k=1."); + } + TORCH_CHECK(num_experts % 4 == 0, "Routing kernel expects that num_experts must be divisible by 4"); TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k"); @@ -73,7 +89,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit // setup args args.mDtypeElt = btg::Dtype::E4m3; args.routing_logits = routing_logits.data_ptr(); - args.routing_bias = routing_bias.data_ptr(); + args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr; args.hidden_states = hidden_states.data_ptr(); args.gemm1_weights = gemm1_weights.data_ptr(); args.output1_scales_scalar = output1_scales_scalar.data_ptr<float>(); @@ -84,11 +100,11 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit args.num_experts = num_experts; args.hidden_size = hidden_states.sizes()[1]; args.top_k = top_k; - args.n_group = n_group; - args.topk_group = topk_group; + args.n_group = n_group.value_or(0); + args.topk_group = topk_group.value_or(0); args.local_expert_offset = local_expert_offset; args.local_num_experts = local_num_experts; - args.routed_scaling_factor = routed_scaling_factor; + args.routed_scaling_factor = routed_scaling_factor.value_or(1.0); args.intermediate_size = intermediate_size; args.mUseRoutingScalesOnInput = use_routing_scales_on_input; @@ -135,15 +151,14 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::Tensor const& routing_logit tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream(routing_logits.get_device()); - routing_runner.run(routing_logits.data_ptr(), routing_bias.data_ptr(), args.num_tokens, args.num_experts, - args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, - args.routed_scaling_factor, expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(), - total_num_padded_tokens.data_ptr<int>(), expanded_idx_to_permuted_idx.data_ptr<int>(), - nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/, permuted_idx_to_token_idx.data_ptr<int>(), - expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(), cta_idx_xy_to_batch_idx.data_ptr<int>(), - cta_idx_xy_to_mn_limit.data_ptr<int>(), num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, - use_routing_scales_on_input, false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), - stream); + routing_runner.run(routing_logits.data_ptr(), args.routing_bias, args.num_tokens, args.num_experts, args.top_k, + args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, + expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(), + expanded_idx_to_permuted_idx.data_ptr<int>(), nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/, + permuted_idx_to_token_idx.data_ptr<int>(), expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(), + cta_idx_xy_to_batch_idx.data_ptr<int>(), cta_idx_xy_to_mn_limit.data_ptr<int>(), + num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, use_routing_scales_on_input, + false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), stream); // MoE kernel except routing TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8."); @@ -228,7 +243,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) m.def( "fp8_per_tensor_scale_moe_runner(" "Tensor routing_logits," - "Tensor routing_bias," + "Tensor? routing_bias," "Tensor hidden_states," "Tensor gemm1_weights," "Tensor output1_scales_scalar," @@ -237,12 +252,12 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) "Tensor output2_scales_scalar," "int num_experts," "int top_k," - "int n_group," - "int topk_group," + "int? n_group," + "int? topk_group," "int intermediate_size," "int local_expert_offset," "int local_num_experts," - "float routed_scaling_factor," + "float? routed_scaling_factor," "bool use_routing_scales_on_input," "int tile_tokens_dim," "int routing_method_type) -> Tensor"); diff --git a/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp b/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp index 56ba59e1ee..f04d46079d 100644 --- a/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp +++ b/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp @@ -27,17 +27,22 @@ namespace torch_ext // This operator applies RMS normalization and RoPE to Q and K tensors in a single CUDA kernel. // The OP performs operations in-place on the input qkv tensor. void fused_qk_norm_rope( - torch::Tensor& qkv, // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim] - int64_t num_heads_q, // Number of query heads - int64_t num_heads_k, // Number of key heads - int64_t num_heads_v, // Number of value heads - int64_t head_dim, // Dimension per head - double eps, // Epsilon for RMS normalization - torch::Tensor& q_weight, // RMSNorm weights for query [head_dim] - torch::Tensor& k_weight, // RMSNorm weights for key [head_dim] - double base, // Base for RoPE computation - bool is_neox, // Whether RoPE is applied in Neox style - torch::Tensor& position_ids // Position IDs for RoPE [num_tokens] + torch::Tensor& qkv, // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim] + int64_t num_heads_q, // Number of query heads + int64_t num_heads_k, // Number of key heads + int64_t num_heads_v, // Number of value heads + int64_t head_dim, // Dimension per head + double eps, // Epsilon for RMS normalization + torch::Tensor& q_weight, // RMSNorm weights for query [head_dim] + torch::Tensor& k_weight, // RMSNorm weights for key [head_dim] + double base, // Base for RoPE computation + bool is_neox, // Whether RoPE is applied in Neox style + torch::Tensor& position_ids, // Position IDs for RoPE [num_tokens] + // parameters for yarn + double factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn. + double low, // threshold for high frequency + double high, // threshold for low frequency + double attention_factor // attention_factor applied on cos and sin ) { // Input validation @@ -68,7 +73,8 @@ void fused_qk_norm_rope( reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()), reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()), static_cast<float>(base), !is_neox, // interleave - reinterpret_cast<int const*>(position_ids.data_ptr()), stream); + reinterpret_cast<int const*>(position_ids.data_ptr()), static_cast<float>(factor), static_cast<float>(low), + static_cast<float>(high), static_cast<float>(attention_factor), stream); } // Register the PyTorch operators @@ -76,7 +82,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float " - "eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids) -> ()"); + "eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids, float factor, float " + "low, float high, float attention_factor) -> ()"); } // Register the CUDA implementation diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 2dc93d5a6c..328cce3d01 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -46,6 +46,7 @@ namespace torch_ext namespace common = tensorrt_llm::common; namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; +using ActivationParams = CUTLASS_MOE_GEMM_NAMESPACE::ActivationParams; using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; // Always use public header as it is just utility functions and types using TmaWarpSpecializedGroupedGemmInput = tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; @@ -93,15 +94,42 @@ public: } }; + template <typename TypeAct> + std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> create_weight_quant_runner() + { + if (isInt8Quant()) + { + return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, uint8_t>>(); + } + else if (isInt4Quant()) + { +#ifdef ENABLE_FP8 + if (mUseW4GroupScaling) + { + return std::make_unique< + kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, TypeAct, TypeAct>>(); + } +#endif + return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>(); + } + else + { + C10_THROW_ERROR_FORMATTED(Error, "Unsupported weight quantization type"); + } + } + FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype, - bool use_deepseek_fp8_block_scale, bool use_w4a8_group_scaling, bool use_mxfp8_act_scaling) + bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_int8_woq_per_channel, + bool use_mxfp8_act_scaling, bool use_fused_finalize) { mActivationDtype = activation_dtype; mWeightDtype = weight_dtype; mOutputDtype = output_dtype; mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale; - mUseW4A8GroupScaling = use_w4a8_group_scaling; + mUseW4GroupScaling = use_w4_group_scaling; + mUseINT8WoqPerChannel = use_int8_woq_per_channel; mUseMxfp8ActScaling = use_mxfp8_act_scaling; + mUseFusedFinalize = use_fused_finalize; mInnerDimMultiplier = 1; // keep consistent with cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -134,10 +162,9 @@ public: mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype); } - if (isNvfp4Quant()) { - mInnerDimMultiplier = 16; + mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG switch (mActivationDtype) { case c10::ScalarType::Half: @@ -149,45 +176,34 @@ public: default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype); } } -#endif - if (isInt4Quant()) + if (isWFP4A16Quant()) { mInnerDimMultiplier = 2; if (mActivationDtype == c10::ScalarType::Half) { -#ifdef ENABLE_FP8 - if (mUseW4A8GroupScaling) - { - mKernelRunner - = std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>(); - } - else - { - mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>(); - } -#else - mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>(); -#endif + mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, __nv_fp4_e2m1>>(); } #ifdef ENABLE_BF16 else if (mActivationDtype == c10::ScalarType::BFloat16) { -#ifdef ENABLE_FP8 - if (mUseW4A8GroupScaling) - { - mKernelRunner = std::make_unique< - kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>(); - } - else - { - mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>(); - } -#else - mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>(); -#endif + mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>>(); } #endif } +#endif + if (isIntWeightOnlyQuant()) + { + if (isInt4Quant()) + { + mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8 + } + switch (mActivationDtype) + { + case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner<half>(); break; + case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break; + default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight"); + } + } if (!mKernelRunner) { C10_THROW_ERROR_FORMATTED(Error, @@ -196,6 +212,8 @@ public: << ", Output: " << torch::toString(mOutputDtype)); } + mKernelRunner->use_fused_finalize_ = mUseFusedFinalize; + mProfiler = std::make_shared<kernels::GemmProfilerBackend>(); mAllProfiles = mKernelRunner->getTactics(); } @@ -218,7 +236,9 @@ public: torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases, torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales, - torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank, + torch::optional<torch::Tensor> const& input_sf, bool const swizzled_input_sf, + torch::optional<torch::Tensor> const& swiglu_alpha, torch::optional<torch::Tensor> const& swiglu_beta, + torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids) { @@ -259,6 +279,22 @@ public: "fc2_expert_biases should match fc2_expert_weights output shape."); } + if (fc1_expert_biases.has_value() || fc2_expert_biases.has_value()) + { + CHECK_INPUT(fc1_expert_biases.value(), mOutputDtype); + CHECK_INPUT(fc2_expert_biases.value(), mOutputDtype); + TORCH_CHECK(fc1_expert_biases.value().dim() == 2, "fc1_expert_biases must be 2D."); + TORCH_CHECK(fc2_expert_biases.value().dim() == 2, "fc2_expert_biases must be 2D."); + TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc1_expert_biases.value().sizes()[0], + "fc1_expert_weights and fc1_expert_biases must have the same number of experts."); + TORCH_CHECK(fc2_expert_weights.sizes()[0] == fc2_expert_biases.value().sizes()[0], + "fc2_expert_weights and fc2_expert_biases must have the same number of experts."); + TORCH_CHECK(fc1_expert_biases.value().sizes()[1] == fc1_expert_weights.sizes()[1], + "fc1_expert_biases should match fc1_expert_weights output shape."); + TORCH_CHECK(fc2_expert_biases.value().sizes()[1] == fc2_expert_weights.sizes()[1], + "fc2_expert_biases should match fc2_expert_weights output shape."); + } + TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0], "input and token_selected_experts must have the same num tokens."); if (token_final_scales) @@ -271,13 +307,31 @@ public: } TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0], "fc1_expert_weights and fc2_expert_weights must have the same number of experts."); - TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, - "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + + if (mUseINT8WoqPerChannel) + { + // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: + // [num_experts, inter_size, hidden_size] + TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + } + else + { + TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be fc2_expert_weights inter size."); + } int experts_per_token = token_selected_experts.sizes()[1]; int64_t num_rows = input.sizes()[0]; int64_t hidden_size = fc2_expert_weights.sizes()[1]; int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + if (mUseINT8WoqPerChannel) + { + // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: + // [num_experts, inter_size, hidden_size] + hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + inter_size = fc2_expert_weights.sizes()[1]; + } if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant()) { @@ -299,7 +353,32 @@ public: int const num_experts_on_rank = fc2_expert_weights.sizes()[0]; auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank); - auto activation_type = ActivationType::Swiglu; + ActivationType base_activation_type = ActivationType::Swiglu; + if (swiglu_alpha.has_value()) + { + CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float); + TORCH_CHECK(swiglu_alpha.value().sizes()[0] == num_experts_on_rank, + "swiglu_alpha must have num_experts_on_rank elements."); + base_activation_type = ActivationType::SwigluBias; + } + if (swiglu_beta.has_value()) + { + CHECK_INPUT(swiglu_beta.value(), at::ScalarType::Float); + TORCH_CHECK(swiglu_beta.value().sizes()[0] == num_experts_on_rank, + "swiglu_beta must have num_experts_on_rank elements."); + base_activation_type = ActivationType::SwigluBias; + } + if (swiglu_limit.has_value()) + { + CHECK_INPUT(swiglu_limit.value(), at::ScalarType::Float); + TORCH_CHECK(swiglu_limit.value().sizes()[0] == num_experts_on_rank, + "swiglu_limit must have num_experts_on_rank elements."); + base_activation_type = ActivationType::SwigluBias; + } + auto activation_params = ActivationParams(base_activation_type, + reinterpret_cast<float const*>(swiglu_alpha.has_value() ? swiglu_alpha.value().const_data_ptr() : nullptr), + reinterpret_cast<float const*>(swiglu_beta.has_value() ? swiglu_beta.value().const_data_ptr() : nullptr), + reinterpret_cast<float const*>(swiglu_limit.has_value() ? swiglu_limit.value().const_data_ptr() : nullptr)); setRunnerProfiles(profile_ids); @@ -309,7 +388,7 @@ public: auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype)); WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total, - static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode); + static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode); auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales); kernels::MoeMinLatencyParams min_latency_params{}; @@ -318,12 +397,12 @@ public: ::tensorrt_llm::kernels::LoraParams lora_params{}; #ifdef USING_OSS_CUTLASS_MOE_GEMM mKernelRunner->runMoe(input.const_data_ptr(), - input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, + input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()), token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr()) : nullptr, fc1_expert_weights.const_data_ptr(), - fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type, + fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params, fc2_expert_weights.const_data_ptr(), fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token), @@ -332,16 +411,16 @@ public: mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream); #else mKernelRunner->runMoe(input.const_data_ptr(), - input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, + input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()), token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr()) : nullptr, fc1_expert_weights.const_data_ptr(), - fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type, + fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params, fc2_expert_weights.const_data_ptr(), fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token), - static_cast<char*>(workspace_info.workspace), output.data_ptr(), + static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream); #endif @@ -354,7 +433,9 @@ public: torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases, torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales, - torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank, + torch::optional<torch::Tensor> const& input_sf, bool const swizzled_input_sf, + torch::optional<torch::Tensor> const& swiglu_alpha, torch::optional<torch::Tensor> const& swiglu_beta, + torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids) { @@ -420,7 +501,32 @@ public: auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank); - auto activation_type = ActivationType::Swiglu; + ActivationType base_activation_type = ActivationType::Swiglu; + if (swiglu_alpha.has_value()) + { + CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float); + TORCH_CHECK(swiglu_alpha.value().sizes()[0] == num_experts_on_rank, + "swiglu_alpha must have num_experts_on_rank elements."); + base_activation_type = ActivationType::SwigluBias; + } + if (swiglu_beta.has_value()) + { + CHECK_INPUT(swiglu_beta.value(), at::ScalarType::Float); + TORCH_CHECK(swiglu_beta.value().sizes()[0] == num_experts_on_rank, + "swiglu_beta must have num_experts_on_rank elements."); + base_activation_type = ActivationType::SwigluBias; + } + if (swiglu_limit.has_value()) + { + CHECK_INPUT(swiglu_limit.value(), at::ScalarType::Float); + TORCH_CHECK(swiglu_limit.value().sizes()[0] == num_experts_on_rank, + "swiglu_limit must have num_experts_on_rank elements."); + base_activation_type = ActivationType::SwigluBias; + } + auto activation_params = ActivationParams(base_activation_type, + reinterpret_cast<float const*>(swiglu_alpha.has_value() ? swiglu_alpha.value().const_data_ptr() : nullptr), + reinterpret_cast<float const*>(swiglu_beta.has_value() ? swiglu_beta.value().const_data_ptr() : nullptr), + reinterpret_cast<float const*>(swiglu_limit.has_value() ? swiglu_limit.value().const_data_ptr() : nullptr)); setRunnerProfiles(profile_ids); @@ -440,7 +546,7 @@ public: min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr()); WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total, - static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode); + static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode); auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales); @@ -448,12 +554,12 @@ public: ::tensorrt_llm::kernels::LoraParams lora_params{}; #ifdef USING_OSS_CUTLASS_MOE_GEMM mKernelRunner->runMoe(input.const_data_ptr(), - input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, + input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()), token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr()) : nullptr, fc1_expert_weights.const_data_ptr(), - fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type, + fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params, fc2_expert_weights.const_data_ptr(), fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token), @@ -462,16 +568,16 @@ public: mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream); #else mKernelRunner->runMoe(input.const_data_ptr(), - input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, + input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()), token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr()) : nullptr, fc1_expert_weights.const_data_ptr(), - fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type, + fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params, fc2_expert_weights.const_data_ptr(), fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token), - static_cast<char*>(workspace_info.workspace), output.data_ptr(), + static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream); #endif @@ -485,6 +591,7 @@ public: return mAllProfiles.size(); } + // TODO Update this to be able to tell if we are profiling swiglu bias void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size, @@ -501,9 +608,20 @@ public: } int64_t const num_rows = input.sizes()[0]; - int64_t const hidden_size = fc2_expert_weights.sizes()[1]; - int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; - int64_t const group_size = isInt4Quant() ? 128 : -1; + int64_t hidden_size = fc2_expert_weights.sizes()[1]; + int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + if (mUseINT8WoqPerChannel) + { + // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: + // [num_experts, inter_size, hidden_size] + hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + inter_size = fc2_expert_weights.sizes()[1]; + } + int64_t const group_size_ + = isInt4Quant() ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size : -1; + int64_t const group_size = isWFP4A16Quant() + ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size + : group_size_; int const num_experts = static_cast<int>(fc2_expert_weights.sizes()[0] * ep_size); // Get specific profile configs according to the profile_id. @@ -530,7 +648,8 @@ public: bool const USE_BIAS = fc1_expert_biases.has_value() || fc2_expert_biases.has_value(); bool const USE_LORA = false; - auto activation_dtype = mUseW4A8GroupScaling ? at::ScalarType::Float8_e4m3fn : mActivationDtype; + auto activation_dtype + = (mUseW4GroupScaling && !isWFP4A16Quant()) ? at::ScalarType::Float8_e4m3fn : mActivationDtype; activation_dtype = isNvfp4Quant() ? at::ScalarType::Long : activation_dtype; #ifdef USING_OSS_CUTLASS_MOE_GEMM mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile, @@ -579,8 +698,10 @@ private: char* mProfileWorkspace = nullptr; bool mUseDeepSeekFP8BlockScaling = false; - bool mUseW4A8GroupScaling = false; + bool mUseW4GroupScaling = false; + bool mUseINT8WoqPerChannel = false; bool mUseMxfp8ActScaling = false; + bool mUseFusedFinalize = true; using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; std::vector<Profile> mAllProfiles; @@ -628,7 +749,7 @@ private: { size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, - min_latency_mode, mUseW4A8GroupScaling); + min_latency_mode, mUseW4GroupScaling); size_t src_to_dest_map_size = experts_per_token * num_rows * sizeof(int); std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size}; @@ -768,8 +889,7 @@ private: && fc1_weight_block.sizes()[2] * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( - hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX) - * TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX, + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX), "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // " "block_scale_vector_size)"); TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)"); @@ -794,7 +914,6 @@ private: TORCH_CHECK(false, "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm"); #endif } - else if (isNvfp4Quant()) { TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization"); @@ -867,27 +986,57 @@ private: return kernels::QuantParams::FP8BlockScaling( static_cast<float const*>(fc1_scales.data_ptr()), static_cast<float const*>(fc2_scales.data_ptr())); } - else if (isInt4Quant()) + else if (isWFP4A16Quant()) { - TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for INT4 quantization"); - TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for INT4 quantization"); + TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization"); + TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization"); + auto& fc1_weight_scales = quant_scales.value()[0]; auto& fc2_weight_scales = quant_scales.value()[1]; - auto& fc1_act_scales = quant_scales.value()[2]; - auto& fc2_act_scales = quant_scales.value()[3]; - auto& fc1_weight_zeros = quant_scales.value()[4]; - auto& fc2_weight_zeros = quant_scales.value()[5]; - auto& fc1_alpha = quant_scales.value()[6]; - auto& fc2_alpha = quant_scales.value()[7]; - int group_size = 128; + int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size; return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()), - static_cast<void const*>(fc2_weight_scales.data_ptr()), - static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr), - static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr), - static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr), - static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr), - static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr), - static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr)); + static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr); + } + else if (isIntWeightOnlyQuant()) + { + TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization"); + if (mUseINT8WoqPerChannel) + { + TORCH_CHECK( + quant_scales.value().size() == 2, "Expecting 2 quant scales for INT8 weight only quantization"); + auto& fc1_weight_scales = quant_scales.value()[0]; + auto& fc2_weight_scales = quant_scales.value()[1]; + return kernels::QuantParams::Int(static_cast<float const*>(fc1_weight_scales.data_ptr()), + static_cast<float const*>(fc2_weight_scales.data_ptr())); + } + else if (isInt4Quant() && mUseW4GroupScaling) + { + TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization"); + + auto& fc1_weight_scales = quant_scales.value()[0]; + auto& fc2_weight_scales = quant_scales.value()[1]; + auto& fc1_act_scales = quant_scales.value()[2]; + auto& fc2_act_scales = quant_scales.value()[3]; + auto& fc1_weight_zeros = quant_scales.value()[4]; + auto& fc2_weight_zeros = quant_scales.value()[5]; + auto& fc1_alpha = quant_scales.value()[6]; + auto& fc2_alpha = quant_scales.value()[7]; + int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size; + return kernels::QuantParams::GroupWise(group_size, + static_cast<void const*>(fc1_weight_scales.data_ptr()), + static_cast<void const*>(fc2_weight_scales.data_ptr()), + static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr), + static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr), + static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr), + static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr), + static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr), + static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr)); + } + else + { + TORCH_CHECK(false, "Unsupported weight only quantization"); + } } else { @@ -907,6 +1056,16 @@ private: && mActivationDtype != c10::ScalarType::Float8_e4m3fn; // FP8 activation does not use FP4 } + bool isWFP4A16Quant() const + { + return mUseW4GroupScaling && mWeightDtype == c10::ScalarType::Byte; + } + + bool isInt8Quant() const + { + return mWeightDtype == c10::ScalarType::Char; + } + bool isInt4Quant() const { return mWeightDtype == c10::ScalarType::QUInt4x2; @@ -917,6 +1076,11 @@ private: return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant(); } + bool isIntWeightOnlyQuant() const + { + return isInt8Quant() || isInt4Quant(); + } + bool isWMxfp4AFp8Quant() const { return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long @@ -935,7 +1099,7 @@ private: TORCH_LIBRARY(trtllm, m) { m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner") - .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool>()) + .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool, bool>()) .def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile) .def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum) .def("run_moe", &torch_ext::FusedMoeRunner::runMoe) diff --git a/cpp/tensorrt_llm/thop/moeUtilOp.cpp b/cpp/tensorrt_llm/thop/moeUtilOp.cpp index d939bcd07f..18e0149315 100644 --- a/cpp/tensorrt_llm/thop/moeUtilOp.cpp +++ b/cpp/tensorrt_llm/thop/moeUtilOp.cpp @@ -83,7 +83,7 @@ void runPermute(void const* input_activations_void, void const* input_sf_void, i reinterpret_cast<ExpandedActivationsType*>(permuted_data_), token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token, num_experts_per_node, quant_params, /*use_per_expert_act_scale*/ false, expert_first_token_offset_, - /* fc1_fp4_act_scale_ */ nullptr, input_sf, /* prequant_scales */ nullptr, stream); + /* fc1_fp4_act_scale_ */ nullptr, input_sf, true, /* prequant_scales */ nullptr, stream); sync_check_cuda_error(stream); } diff --git a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp new file mode 100644 index 0000000000..39459c8da4 --- /dev/null +++ b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp @@ -0,0 +1,508 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h" +#include "tensorrt_llm/runtime/torchUtils.h" +#include "tensorrt_llm/thop/thUtils.h" +#include <ATen/cuda/EmptyTensor.h> +#include <ATen/ops/index_select.h> +#include <c10/util/Exception.h> +#include <cstdint> +#include <memory> +#include <optional> + +namespace torch_ext +{ +namespace btg = batchedGemm::trtllm::gen; +using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType; +using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner; + +torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::Tensor const& routing_logits, + torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states, + std::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights, + torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias, + std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta, + std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights, + torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias, + std::optional<torch::Tensor> const& output1_scale_scalar, + std::optional<torch::Tensor> const& output1_scale_gate_scalar, + std::optional<torch::Tensor> const& output2_scale_scalar, int64_t const num_experts, int64_t const top_k, + std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size, + std::optional<int64_t> const hidden_size_output, int64_t const local_expert_offset, int64_t const local_num_experts, + std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, + btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t moeConfigIndex) +{ + auto const sm = tensorrt_llm::common::getSMVersion(); + TORCH_CHECK(sm == 100, "Only SM100 is supported by FP4 block scale MOE"); + TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64, + "tile_tokens_dim must be 8, 16, 32, 64"); + TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float + || routing_logits.scalar_type() == at::ScalarType::BFloat16, + "routing_logits must be float or bfloat16."); + TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D."); + TORCH_CHECK( + routing_logits.sizes()[0] == hidden_states.sizes()[0], "routing_logits dim0 must match hidden_states dim0."); + TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits dim1 must match num_experts."); + if (routing_bias.has_value()) + { + TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16, "routing_bias must be bfloat16."); + TORCH_CHECK(routing_bias.value().dim() == 1, "routing_bias must be 1D."); + TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape."); + } + + if (n_group.has_value() && n_group.value() != 0) + { + TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3, + "Routing kernel with groups implies DeepSeekV3 routing method."); + TORCH_CHECK(topk_group.has_value(), "if n_group is given, topk_group must be given"); + TORCH_CHECK(num_experts % n_group.value() == 0, "num_experts must be divisible by n_group"); + TORCH_CHECK(top_k <= 8 && top_k > 0, "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."); + TORCH_CHECK(topk_group.value() <= 4 && topk_group.value() > 0, + "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."); + TORCH_CHECK(topk_group.value() <= n_group.value(), "n_group must not be smaller than topk_group."); + // This check ensures we have enough experts in the selected groups to handle the top_k routing + TORCH_CHECK(top_k < (topk_group.value() * num_experts / n_group.value()), + "top_k must be less than total number of experts in selected groups"); + } + else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Renormalize + || static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::RenormalizeNaive) + { + TORCH_CHECK(top_k <= 8 && top_k > 0, + "Current routing kernel (no groups, renormalize) only supports top_k<=8 && top_k>0."); + } + + TORCH_CHECK(num_experts % 4 == 0, "Routing kernel expects that num_experts must be divisible by 4"); + TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k"); + + tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::MoERunnerArgs args; + tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::MoEWorkspace workspace; + + // setup args + args.mDtypeElt = dtype; + args.routing_logits = routing_logits.data_ptr(); + args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr; + args.hidden_states = hidden_states.data_ptr(); + args.hidden_states_scale = hidden_states_scale.has_value() ? hidden_states_scale.value().data_ptr() : nullptr; + args.gemm1_weights = gemm1_weights.data_ptr(); + args.gemm1_weights_scale = gemm1_weights_scale.data_ptr(); + args.gemm2_weights = gemm2_weights.data_ptr(); + args.gemm2_weights_scale = gemm2_weights_scale.data_ptr(); + args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr; + args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr; + args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr; + args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr; + args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr; + args.output1_scales_scalar + = output1_scale_scalar.has_value() ? output1_scale_scalar.value().data_ptr<float>() : nullptr; + args.output1_scales_gate_scalar + = output1_scale_gate_scalar.has_value() ? output1_scale_gate_scalar.value().data_ptr<float>() : nullptr; + args.output2_scales_scalar + = output2_scale_scalar.has_value() ? output2_scale_scalar.value().data_ptr<float>() : nullptr; + args.num_tokens = hidden_states.sizes()[0]; + args.num_experts = num_experts; + args.hidden_size = hidden_states.sizes()[1]; + args.hidden_size_output = hidden_size_output.value_or(args.hidden_size); + args.top_k = top_k; + args.n_group = n_group.value_or(0); + args.topk_group = topk_group.value_or(0); + args.local_expert_offset = local_expert_offset; + args.local_num_experts = local_num_experts; + args.routed_scaling_factor = routed_scaling_factor.value_or(1.0); + args.intermediate_size = intermediate_size; + + // allocate workspace for routing kernel + at::Tensor num_tokens_per_expert + = at::detail::empty_cuda({num_experts}, at::ScalarType::Int, routing_logits.device(), std::nullopt); + int32_t max_num_padded_tokens + = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::getMaxPermutedPaddedCount( + args.num_tokens, top_k, num_experts, tile_tokens_dim); + at::Tensor total_num_padded_tokens + = at::empty({}, at::TensorOptions().device(routing_logits.device()).dtype(at::ScalarType::Int)); + at::Tensor expanded_idx_to_permuted_idx = at::detail::empty_cuda( + {args.num_tokens * args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt); + + at::Tensor permuted_idx_to_token_idx + = at::detail::empty_cuda({max_num_padded_tokens}, at::ScalarType::Int, routing_logits.device(), std::nullopt); + at::Tensor expert_weights = at::detail::empty_cuda( + {args.num_tokens, args.top_k}, at::ScalarType::BFloat16, routing_logits.device(), std::nullopt); + at::Tensor expert_indexes = at::detail::empty_cuda( + {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt); + at::Tensor expert_count_histogram = at::detail::empty_cuda({2 * 256}, + at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts + routing_logits.device(), std::nullopt); + + int32_t const sf_block_size = 32; + // allocate workspace for activation/gemm/finalize kernels + auto const gemm1_output_type + = dtype == btg::Dtype::Bfloat16 ? at::ScalarType::BFloat16 : at::ScalarType::Float8_e4m3fn; + at::Tensor gemm1_output = at::detail::empty_cuda( + {max_num_padded_tokens, intermediate_size}, gemm1_output_type, hidden_states.device(), std::nullopt); + + std::optional<at::Tensor> gemm1_output_scale; + if (dtype == btg::Dtype::MxE4m3) + { + int64_t sf_size + = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens, intermediate_size / sf_block_size); + gemm1_output_scale = at::detail::empty_cuda({sf_size}, SF_DTYPE, hidden_states.device(), std::nullopt); + } + + at::Tensor gemm2_output = at::detail::empty_cuda( + {max_num_padded_tokens, args.hidden_size}, at::ScalarType::BFloat16, hidden_states.device(), std::nullopt); + + int32_t max_num_ctas = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::getMaxNumCtasInBatchDim( + args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); + at::Tensor cta_idx_xy_to_batch_idx + = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int, routing_logits.device(), std::nullopt); + at::Tensor cta_idx_xy_to_mn_limit + = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int, routing_logits.device(), std::nullopt); + at::Tensor num_non_exiting_ctas + = at::empty({}, at::TensorOptions().device(routing_logits.device()).dtype(at::ScalarType::Int)); + + // FIXME: check shape + TORCH_CHECK(dtype == btg::Dtype::MxE4m3 || dtype == btg::Dtype::Bfloat16 || dtype == btg::Dtype::E4m3, + "dtype must be MxE4m3 or Bfloat16 or E4m3."); + if (dtype == btg::Dtype::MxE4m3) + { + TORCH_CHECK(hidden_states_scale.has_value(), "hidden_states_scale must be provided for MxE4m3."); + } + else + { + TORCH_CHECK( + !hidden_states_scale.has_value(), "hidden_states_scale must not be provided for Bfloat16 and E4m3."); + } + + // + // TopK routing + // + + tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); + auto const& stream = at::cuda::getCurrentCUDAStream(routing_logits.get_device()); + routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, + args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, + expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(), + expanded_idx_to_permuted_idx.data_ptr<int>(), nullptr, /*permuted_idx_to_expanded_idx.data_ptr<int>(),*/ + permuted_idx_to_token_idx.data_ptr<int>(), expert_weights.data_ptr(), num_tokens_per_expert.data_ptr<int>(), + cta_idx_xy_to_batch_idx.data_ptr<int>(), cta_idx_xy_to_mn_limit.data_ptr<int>(), + num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt, false /* use_routing_scales_on_input */, + false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), stream); + + // + // FC13 (gemm1) + FC2 (gemm2) + // + + if (dtype == btg::Dtype::MxE4m3 || dtype == btg::Dtype::E4m3) + { + TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, + "hidden_states must be Float8_e4m3fn, got %s.", c10::toString(hidden_states.scalar_type())); + } + else + { + TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::BFloat16, "hidden_states must be BFloat16, got %s.", + c10::toString(hidden_states.scalar_type())); + } + if (dtype == btg::Dtype::MxE4m3) + { + TORCH_CHECK(hidden_states_scale->scalar_type() == SF_DTYPE, "hidden_states_scale must be UInt8, got %s.", + c10::toString(hidden_states_scale->scalar_type())); + + TORCH_CHECK(hidden_states_scale->dim() == 1, "hidden_states_scale must be 1D."); + TORCH_CHECK(hidden_states_scale->sizes()[0] + == tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens, args.hidden_size / sf_block_size), + "hidden_states_scale has incorrect size"); + } + + TORCH_CHECK(gemm1_weights.scalar_type() == FLOAT4_E2M1X2, "gemm1_weights must be byte, got %s.", + c10::toString(gemm1_weights.scalar_type())); + + TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D."); + TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even."); + TORCH_CHECK(2 * intermediate_size == gemm1_weights.sizes()[1], "intermediate_size has incorrect dim 1."); + // The actual shape of the weights[2] is 2 times larger than and hidden_states[1] + // due to the fact that 2 e2m1 are packed into 1 byte for FP4 weights. + TORCH_CHECK(gemm1_weights.sizes()[2] * 2 == hidden_states.sizes()[1], + "the third dimension of weights must be equal to hidden_size."); + + TORCH_CHECK(gemm1_weights_scale.scalar_type() == SF_DTYPE, "gemm1_weights_scale must be UInt8, got %s.", + c10::toString(gemm1_weights_scale.scalar_type())); + + TORCH_CHECK(gemm1_weights_scale.dim() == 3, "gemm1_weights_scale must be 3D."); + TORCH_CHECK(gemm1_weights_scale.sizes()[0] == local_num_experts, "gemm1_weights_scale has incorrect dim 0."); + TORCH_CHECK(intermediate_size % sf_block_size == 0, "the second dimension of weights must be a multiple of 32."); + TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1."); + TORCH_CHECK( + gemm1_weights_scale.sizes()[2] == args.hidden_size / sf_block_size, "gemm1_weights_scale has incorrect dim 2."); + + if (gemm1_bias.has_value()) + { + TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.", + c10::toString(gemm1_bias.value().scalar_type())); + TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D."); + TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0."); + TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1."); + } + + if (gemm1_alpha.has_value()) + { + TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.", + c10::toString(gemm1_alpha.value().scalar_type())); + TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D."); + TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0."); + } + if (gemm1_beta.has_value()) + { + TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.", + c10::toString(gemm1_beta.value().scalar_type())); + TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D."); + TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0."); + } + if (gemm1_clamp_limit.has_value()) + { + TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float, + "gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type())); + TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D."); + TORCH_CHECK( + gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0."); + } + + TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte, got %s.", + c10::toString(gemm2_weights.scalar_type())); + + TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D."); + // / 2 to compensate for the fact that we pack 2 e2m1 into 1 byte. + TORCH_CHECK(gemm2_weights.sizes()[2] == intermediate_size / 2, + "the third dimension of weights must be equal to intermediate_size."); + + TORCH_CHECK(gemm2_weights_scale.scalar_type() == SF_DTYPE, "gemm2_weights_scale must be UInt8, got %s.", + c10::toString(gemm2_weights_scale.scalar_type())); + + TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D."); + TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0."); + TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1."); + TORCH_CHECK(gemm2_weights_scale.sizes()[2] == intermediate_size / sf_block_size, + "gemm2_weights_scale has incorrect dim 2."); + + if (gemm2_bias.has_value()) + { + TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.", + c10::toString(gemm2_bias.value().scalar_type())); + TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D."); + TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0."); + TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1."); + } + + if (dtype == btg::Dtype::E4m3) + { + TORCH_CHECK(output1_scale_scalar.has_value(), "output1_scale_scalar must be provided for MxE4m3."); + TORCH_CHECK(output1_scale_gate_scalar.has_value(), "output1_scale_gate_scalar must be provided for MxE4m3."); + TORCH_CHECK(output2_scale_scalar.has_value(), "output2_scale_scalar must be provided for MxE4m3."); + + TORCH_CHECK( + output1_scale_scalar->scalar_type() == at::ScalarType::Float, "output1_scales_scalar must be float."); + TORCH_CHECK(output1_scale_scalar->dim() == 1, "output1_scales_scalar must be 1D."); + TORCH_CHECK( + output1_scale_scalar->sizes()[0] == local_num_experts, "output1_scales_scalar has incorrect dim 0."); + + TORCH_CHECK(output1_scale_gate_scalar->scalar_type() == at::ScalarType::Float, + "output1_scales_gate_scalar must be float."); + TORCH_CHECK(output1_scale_gate_scalar->dim() == 1, "output1_scales_gate_scalar must be 1D."); + TORCH_CHECK(output1_scale_gate_scalar->sizes()[0] == local_num_experts, + "output1_scales_gate_scalar has incorrect dim 0."); + + TORCH_CHECK( + output2_scale_scalar->scalar_type() == at::ScalarType::Float, "output2_scales_scalar must be float."); + TORCH_CHECK(output2_scale_scalar->dim() == 1, "output2_scales_scalar must be 1D."); + TORCH_CHECK( + output2_scale_scalar->sizes()[0] == local_num_experts, "output2_scales_scalar has incorrect dim 0."); + } + + // allocate output + at::Tensor output = at::detail::empty_cuda({args.num_tokens, args.hidden_size_output.value()}, + at::ScalarType::BFloat16, hidden_states.device(), std::nullopt); + + // setup workspace + workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr<int>(); + workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.ProjUpTileN = tile_tokens_dim; + workspace.routing_expert_indexes = expert_indexes.data_ptr<int>(); + workspace.permuted_idx_size = total_num_padded_tokens.data_ptr<int>(); + workspace.expanded_idx_to_permuted_idx + = expanded_idx_to_permuted_idx.data_ptr<int>(); // Needed by permute/finalize kernels + workspace.permuted_idx_to_token_idx = permuted_idx_to_token_idx.data_ptr<int>(); // Needed by permuteGemm1 kernel + workspace.expert_weights = expert_weights.data_ptr(); // Consumed by finalize kernel + + workspace.cta_idx_xy_to_batch_idx = cta_idx_xy_to_batch_idx.data_ptr<int>(); + workspace.cta_idx_xy_to_mn_limit = cta_idx_xy_to_mn_limit.data_ptr<int>(); + workspace.num_non_exiting_ctas = num_non_exiting_ctas.data_ptr<int>(); + + // gemm1 intermediate ws + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale + = gemm1_output_scale.has_value() ? reinterpret_cast<float*>(gemm1_output_scale->data_ptr()) : nullptr; + + // gemm2 intermediate ws + workspace.gemm2_output = gemm2_output.data_ptr(); + workspace.gemm2_output_scale = nullptr; + args.output = output.data_ptr(); + args.output_scale = nullptr; + + auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); + at::Tensor workspace_fc1 = at::detail::empty_cuda( + {std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt); + at::Tensor workspace_fc2 = at::detail::empty_cuda( + {std::get<1>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt); + workspace.bmm1_workspace = workspace_fc1.data_ptr(); + workspace.bmm2_workspace = workspace_fc2.data_ptr(); + auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device()); + moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex); + return output; +} + +// Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow +// use with the torch workflow autotuner class. +class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder +{ + +public: + explicit Bf16MxE2m1BlockScaleMoeRunner(int64_t tileTokensDim, int64_t actType) + : mTileTokensDim(tileTokensDim) + { + mRunner = std::make_unique<RunnerType>(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, mTileTokensDim, + static_cast<tensorrt_llm::kernels::ActType>(actType)); + } + + [[nodiscard]] std::vector<int64_t> getValidConfigs( + int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const + { + return mRunner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens); + } + + // BF16 run does not use hidden_states_scale + [[nodiscard]] torch::Tensor run(torch::Tensor const& routing_logits, + std::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states, + torch::Tensor const& gemm1_weights, torch::Tensor const& gemm1_weights_scale, + std::optional<torch::Tensor> const& gemm1_bias, std::optional<torch::Tensor> const& gemm1_alpha, + std::optional<torch::Tensor> const& gemm1_beta, std::optional<torch::Tensor> const& gemm1_clamp_limit, + torch::Tensor const& gemm2_weights, torch::Tensor const& gemm2_weights_scale, + std::optional<torch::Tensor> const& gemm2_bias, int64_t num_experts, int64_t top_k, + std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, std::optional<double> routed_scaling_factor, + int64_t routing_method_type, int64_t moeConfigIndex) + { + // Autotuner has requested a default or 'fallback' config index + if (moeConfigIndex == -1) + { + auto const num_tokens = hidden_states.sizes()[0]; + auto const hidden_size = hidden_states.sizes()[1]; + + moeConfigIndex = mRunner->getDefaultValidConfigIndex( + top_k, hidden_size, intermediate_size, local_num_experts, num_tokens); + } + + return dtype_mxe2m1_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, std::nullopt, + gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, + gemm2_weights_scale, gemm2_bias, std::nullopt, std::nullopt, std::nullopt, num_experts, top_k, n_group, + topk_group, intermediate_size, std::nullopt, local_expert_offset, local_num_experts, routed_scaling_factor, + mTileTokensDim, routing_method_type, mDtypeAct, *mRunner, moeConfigIndex); + } + +private: + using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner; + + std::unique_ptr<RunnerType> mRunner; + + btg::Dtype mDtypeAct{btg::Dtype::Bfloat16}; + btg::Dtype mDtypeWeights{btg::Dtype::MxE2m1}; + bool mUseDeepSeekFp8{false}; + int64_t mTileTokensDim; +}; + +class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder +{ + +public: + explicit MxE4m3MxE2m1BlockScaleMoeRunner(int64_t tileTokensDim, int64_t actType, bool isMxFp8) + : mDtypeAct(isMxFp8 ? btg::Dtype::MxE4m3 : btg::Dtype::E4m3) + , mTileTokensDim(tileTokensDim) + { + mRunner = std::make_unique<RunnerType>(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, mTileTokensDim, + static_cast<tensorrt_llm::kernels::ActType>(actType)); + } + + [[nodiscard]] std::vector<int64_t> getValidConfigs( + int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const + { + return mRunner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens); + } + + [[nodiscard]] torch::Tensor run(torch::Tensor const& routing_logits, + std::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states, + std::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights, + torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias, + std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta, + std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights, + torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias, + std::optional<torch::Tensor> const& output1_scale_scalar, + std::optional<torch::Tensor> const& output1_scale_gate_scalar, + std::optional<torch::Tensor> const& output2_scale_scalar, int64_t num_experts, int64_t top_k, + std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t intermediate_size, + std::optional<int64_t> const hidden_size_output, int64_t local_expert_offset, int64_t local_num_experts, + std::optional<double> routed_scaling_factor, int64_t routing_method_type, int64_t moeConfigIndex) + { + // Autotuner has requested a default or 'fallback' config index + if (moeConfigIndex == -1) + { + auto const num_tokens = hidden_states.sizes()[0]; + auto const hidden_size = hidden_states.sizes()[1]; + + moeConfigIndex = mRunner->getDefaultValidConfigIndex( + top_k, hidden_size, intermediate_size, local_num_experts, num_tokens); + } + + return dtype_mxe2m1_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale, + gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, + gemm2_weights_scale, gemm2_bias, output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar, + num_experts, top_k, n_group, topk_group, intermediate_size, hidden_size_output, local_expert_offset, + local_num_experts, routed_scaling_factor, mTileTokensDim, routing_method_type, mDtypeAct, *mRunner, + moeConfigIndex); + } + +private: + using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner; + + std::unique_ptr<RunnerType> mRunner; + + btg::Dtype mDtypeAct{btg::Dtype::MxE4m3}; + btg::Dtype mDtypeWeights{btg::Dtype::MxE2m1}; + bool mUseDeepSeekFp8{false}; + int64_t mTileTokensDim; +}; + +} // namespace torch_ext + +// Accepts CUDA tensor only +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.class_<torch_ext::Bf16MxE2m1BlockScaleMoeRunner>("Bf16MxE2m1BlockScaleMoERunner") + .def(torch::init<int64_t, int64_t>()) + .def("get_valid_configs", &torch_ext::Bf16MxE2m1BlockScaleMoeRunner::getValidConfigs) + .def("run_moe", &torch_ext::Bf16MxE2m1BlockScaleMoeRunner::run); + + m.class_<torch_ext::MxE4m3MxE2m1BlockScaleMoeRunner>("MxE4m3MxE2m1BlockScaleMoERunner") + .def(torch::init<int64_t, int64_t, bool>()) + .def("get_valid_configs", &torch_ext::MxE4m3MxE2m1BlockScaleMoeRunner::getValidConfigs) + .def("run_moe", &torch_ext::MxE4m3MxE2m1BlockScaleMoeRunner::run); +} diff --git a/cpp/tensorrt_llm/thop/mxFp8Quantize.cpp b/cpp/tensorrt_llm/thop/mxFp8Quantize.cpp new file mode 100644 index 0000000000..ba651f2886 --- /dev/null +++ b/cpp/tensorrt_llm/thop/mxFp8Quantize.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/thop/thUtils.h" + +#include <ATen/cuda/EmptyTensor.h> + +#include <cuda_fp16.h> + +#include <cstdint> + +namespace torch_ext +{ +// self: [M, K], fp16/bf16/fp8_quantized +// mxfp8: sfVecSize = 32 +// alignment: sfVecSize +// isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in linear layout. +// See QuantizationSFLayout enum for more details about the two layouts. +// returns self_mxfp8, self_block_scale_factors +// self_mxfp8: [M, K], Float8_e4m3fn +// self_block_scale_factors: ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE +std::tuple<at::Tensor, at::Tensor> mxfp8_quantize( + at::Tensor const& self, bool isSfSwizzledLayout, int64_t alignment = 32) +{ + CHECK_TH_CUDA(self); + CHECK_CONTIGUOUS(self); + + // Fixed SF_VEC_SIZE as 32 + static constexpr int SF_VEC_SIZE = 32; + TORCH_CHECK(alignment % SF_VEC_SIZE == 0, "alignment must be divisible by SF_VEC_SIZE = 32"); + + auto const& inputShape = self.sizes(); + auto const& rank = inputShape.size(); + + TORCH_CHECK(rank >= 2, "Input should be >=2D tensor."); + int64_t m = 1; + for (size_t i = 0; i < rank - 1; i++) + { + m *= inputShape[i]; + } + auto const k = inputShape[rank - 1]; + TORCH_CHECK(k % SF_VEC_SIZE == 0, "k must be divisible by SF_VEC_SIZE = 32"); + auto const padded_k = ((k + alignment - 1) / alignment) * alignment; + + std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end()); + outputShape[rank - 1] = padded_k; + + at::Tensor valMxFP8 + = at::detail::empty_cuda(outputShape, at::ScalarType::Float8_e4m3fn, self.device(), /* stride */ std::nullopt); + + int64_t SFSize = isSfSwizzledLayout ? tensorrt_llm::computeSwizzledLayoutSFSize(m, padded_k / SF_VEC_SIZE) + : tensorrt_llm::computeLinearLayoutSFSize(m, padded_k / SF_VEC_SIZE); + + at::Tensor scaleFP8SF + = at::detail::empty_cuda({SFSize}, SF_DTYPE, self.device(), /* stride */ std::nullopt); // 1D tensor + + const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); + + auto const layout = isSfSwizzledLayout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED + : tensorrt_llm::QuantizationSFLayout::LINEAR; + +#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \ + tensorrt_llm::kernels::invokeMxFP8Quantization(1, m, k, padded_k, reinterpret_cast<T*>(self.data_ptr()), \ + reinterpret_cast<int64_t*>(valMxFP8.data_ptr()), reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), layout, \ + mMultiProcessorCount, at::cuda::getCurrentCUDAStream(self.get_device())); + + if (self.scalar_type() == at::ScalarType::Half) + { + LAUNCH_MXFP8_QUANTIZE_KERNEL(half) + } + else if (self.scalar_type() == at::ScalarType::BFloat16) + { +#ifdef ENABLE_BF16 + LAUNCH_MXFP8_QUANTIZE_KERNEL(__nv_bfloat16) +#else + C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled to quantize an bf16 tensor to mxfp8."); +#endif + } + else + { + C10_THROW_ERROR(NotImplementedError, "mxfp8_quantize only supports input tensor with dtypes fp16/bf16."); + } + +#undef LAUNCH_MXFP8_QUANTIZE_KERNEL + + return {valMxFP8, scaleFP8SF}; +} +} // namespace torch_ext + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "mxfp8_quantize(Tensor input, bool swizzedLayout=True, int alignment=32) " + "-> (Tensor, Tensor)"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("mxfp8_quantize", &torch_ext::mxfp8_quantize); +} diff --git a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp index 4bf061b5b9..b6feba15e6 100644 --- a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp +++ b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp @@ -349,6 +349,55 @@ Tensor pack_int8_tensor_to_packed_int4(Tensor weight) return packed_weight; } +Tensor mxfp4_dequantize_unswizzled(Tensor weight, Tensor scale, int64_t group_size) +{ + // weight (n, k / 2) + // scale (n, k / group_size) + + CHECK_CPU(weight); + CHECK_CPU(scale); + CHECK_CONTIGUOUS(weight); + CHECK_CONTIGUOUS(scale); + TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor"); + TORCH_CHECK(weight.dtype() == torch::kUInt8, "Weight must be a packed int8 tensor"); + TORCH_CHECK(scale.dtype() == torch::kUInt8, "Scale must be a int8 tensor"); + + TORCH_CHECK(weight.size(0) == scale.size(0)) + TORCH_CHECK(weight.size(1) * 2 == scale.size(1) * group_size) + + uint8_t* weight_packed_ptr = get_ptr<uint8_t>(weight); + __nv_fp8_e8m0* scale_ptr = reinterpret_cast<__nv_fp8_e8m0*>(get_ptr<uint8_t>(scale)); + + int const n = weight.size(0); + int const k = weight.size(1) * 2; + + Tensor dequant_weight = torch::empty({n, k}, torch::dtype(torch::kFloat).device(torch::kCPU).requires_grad(false)); + float* dequant_weight_ptr = get_ptr<float>(dequant_weight); + + float fp4_lut[] = {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0}; + + for (int packed_idx = 0; packed_idx < weight.numel(); ++packed_idx) + { + int8_t weight_packed_data = weight_packed_ptr[packed_idx]; + + uint8_t weight_low_ = weight_packed_data & 0xF; + uint8_t weight_high_ = (weight_packed_data & 0xF0) >> 4; + + float weight_low = fp4_lut[weight_low_]; + float weight_high = fp4_lut[weight_high_]; + + int scale_n_idx = packed_idx / (k / 2); + int scale_k_idx = ((packed_idx * 2) % k) / group_size; + + float scale_ = static_cast<float>(scale_ptr[scale_n_idx * scale.size(1) + scale_k_idx]); + + dequant_weight_ptr[2 * packed_idx] = weight_low * scale_; + dequant_weight_ptr[2 * packed_idx + 1] = weight_high * scale_; + } + + return dequant_weight; +} + } // namespace torch_ext // Utility methods that may be useful for preprocessing weights in torch. @@ -380,3 +429,6 @@ static auto permute_B_rows_for_mixed_gemm = torch::RegisterOperators("trtllm::_permute_B_rows_for_mixed_gemm", &torch_ext::permute_B_rows_for_mixed_gemm); static auto subbyte_transpose = torch::RegisterOperators("trtllm::_subbyte_transpose", &torch_ext::subbyte_transpose); + +static auto mxfp4_dequantize_unswizzled + = torch::RegisterOperators("trtllm::mxfp4_dequantize_unswizzled", &torch_ext::mxfp4_dequantize_unswizzled); diff --git a/cpp/tests/runtime/gptDecoderBatchedTest.cpp b/cpp/tests/runtime/gptDecoderBatchedTest.cpp index 7c152f48a9..338b974aa0 100644 --- a/cpp/tests/runtime/gptDecoderBatchedTest.cpp +++ b/cpp/tests/runtime/gptDecoderBatchedTest.cpp @@ -104,15 +104,14 @@ void newRequests(std::vector<std::shared_ptr<tb::LlmRequest>> const& requests, T SizeType32 maxSequenceLength, tb::DecoderInputBuffers& inputBuffers, decoder::DecoderState& decoderState) { auto const& decoderStream = *decoder.getDecoderStream(); - auto const bufferManager = BufferManager{std::make_shared<CudaStream>(runtimeStream.get())}; auto batchSlotsRange = BufferRange<SizeType32>(*batchSlots); auto const localBatchSize = batchSlots->getSize(); tb::CreateNewDecoderRequests createNewDecoderRequests(false, false, false); - auto [lookaheadPrompt, lookaheadAlgoConfigs] = createNewDecoderRequests.createDecoderRequests(requests, - inputBuffers.inputsIds, decodingConfig, decoderState, bufferManager, logitsType, modelConfig, worldConfig, - runtimeStream, decoderStream, maxSequenceLength, std::nullopt); + auto [lookaheadPrompt, lookaheadAlgoConfigs] + = createNewDecoderRequests.createDecoderRequests(requests, inputBuffers.inputsIds, decodingConfig, decoderState, + logitsType, modelConfig, worldConfig, runtimeStream, decoderStream, maxSequenceLength, std::nullopt); std::vector<SamplingConfig> samplingConfigs; samplingConfigs.reserve(requests.size()); diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index bfc62acc3f..8e58ee77f4 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -3053,189 +3053,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerVariableWindowAttentionWithReuseTest) assertBlocks(seq3, {4}, {6}); } -namespace -{ -KVCacheManager setupKvCacheManagerForHashTest(bool enableBlockReuse) -{ - auto constexpr numLayers = 2; - auto constexpr numHeads = 2; - auto constexpr sizePerHead = 64; - auto constexpr tokensPerBlock = 4; - auto constexpr maxNumSequences = 8; - auto constexpr maxBeamWidth = 1; - auto constexpr sinkTokenLength = 0; - auto const stream = std::make_shared<tr::CudaStream>(); - - auto constexpr maxBlocksPerSeq = 8; - auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq; - auto constexpr maxAttentionWindow = maxNumTokens; - - auto constexpr blocksInPrimaryPool = 16; - auto constexpr blocksInSecondaryPool = 0; - - auto constexpr onboardBlocks = true; - - auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; - - return KVCacheManager(std::vector<SizeType32>(numLayers, numHeads), sizePerHead, tokensPerBlock, blocksPerWindow, - maxNumSequences, maxBeamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, - CacheType::kSELF, std::nullopt, nullptr, - /*enableHashKey*/ true); -} - -std::vector<size_t> getHashAndRetrieveBlocksByHashTest( - BlockManager const& blockManager, std::vector<KVCacheBlock::IdType> const& blockIds, SizeType32 windowSize) -{ - std::vector<size_t> blockHashes; - for (auto blockId : blockIds) - { - blockHashes.emplace_back(blockManager.getBlockById(blockId, windowSize)->getHash()); - } - std::vector<BlockPtr> blockPtrs; - for (auto hash : blockHashes) - { - auto range = blockManager.getBlocksByHash(hash, windowSize); - BlockPtr const prevBlock = blockPtrs.empty() ? nullptr : blockPtrs.back(); - BlockPtr thisBlock = nullptr; - for (auto it = range.first; it != range.second; ++it) - { - if (it->second->getPrevBlockInSeq() == prevBlock) - { - thisBlock = it->second; - break; - } - } - EXPECT_NE(thisBlock, nullptr); - blockPtrs.emplace_back(thisBlock); - } - EXPECT_EQ(blockHashes.size(), blockPtrs.size()); - for (size_t i = 0; i < blockHashes.size(); i++) - { - EXPECT_EQ(blockManager.getBlockById(blockIds[i], windowSize), blockPtrs[i]); - } - return blockHashes; -} -} // namespace - -TEST_F(KVCacheManagerTest, KVCacheManagerHashKeyTest) -{ - auto kvCacheManager = setupKvCacheManagerForHashTest(false); - - auto const& blockManager = kvCacheManager.getBlockManager(); - - SizeType32 constexpr maxNewTokens = 4; - - // prepare tokens with token[i] = 1000 + i - TokenIdType constexpr firstToken = 1000; - - auto constexpr beamWidth = 1; - tr::SamplingConfig const samplingConfig{beamWidth}; - bool constexpr isStreaming{false}; - - SizeType32 requestId = 0; - int inputLength = 16; - auto inputTokens = std::make_shared<VecTokens>(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - auto llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - auto constexpr beamIdx = 0; - - /////////////////////////////////////////////////////////////////////////// - // add a request and then remove it without reuse - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0); - - auto const onlyWindowSize = theOnlyWindowSize(kvCacheManager); - - auto& blockIds = seq.getCacheBlockIds(onlyWindowSize).at(beamIdx); - EXPECT_THAT(blockIds, ::testing::ElementsAreArray({0, 1, 2, 3})); - - // get blocks by hash and try to retrieve them by hash - auto blockHashes = getHashAndRetrieveBlocksByHashTest(blockManager, blockIds, onlyWindowSize); - - EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - - // blocks are all removed - for (auto hash : blockHashes) - { - auto range = blockManager.getBlocksByHash(hash, onlyWindowSize); - EXPECT_EQ(range.first, range.second); - } - EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); -} - -TEST_F(KVCacheManagerTest, KVCacheManagerHashKeyWithReuseTest) -{ - auto kvCacheManager = setupKvCacheManagerForHashTest(true); - - auto const& blockManager = kvCacheManager.getBlockManager(); - - SizeType32 constexpr maxNewTokens = 4; - - // prepare tokens with token[i] = 1000 + i - TokenIdType constexpr firstToken = 1000; - - auto constexpr beamWidth = 1; - tr::SamplingConfig const samplingConfig{beamWidth}; - bool constexpr isStreaming{false}; - - SizeType32 requestId = 0; - int inputLength = 16; - auto inputTokens = std::make_shared<VecTokens>(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - auto llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - auto constexpr beamIdx = 0; - - /////////////////////////////////////////////////////////////////////////// - // add a request and then remove it with reuse - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq0 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0); - - EXPECT_EQ(blockManager.getNumPools(), 1); - auto const onlyWindowSize = theOnlyWindowSize(kvCacheManager); - - auto& blockIds0 = seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx); - EXPECT_THAT(blockIds0, ::testing::ElementsAreArray({0, 1, 2, 3})); - - // get blocks by hash and try to retrieve them by hash - auto blockHashes = getHashAndRetrieveBlocksByHashTest(blockManager, blockIds0, onlyWindowSize); - - EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - - // TODO: Make reused blocks accessible by hash, after sequence removed. Test here. - - /////////////////////////////////////////////////////////////////////////// - // add a new request with same prefix - requestId = 1; - inputLength = 20; - inputTokens->resize(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq1 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 15); - auto& blockIds1 = seq1.getCacheBlockIds(onlyWindowSize).at(beamIdx); - EXPECT_THAT(blockIds1, ::testing::ElementsAreArray({0, 1, 2, 3, 4})); - - std::ignore = getHashAndRetrieveBlocksByHashTest(blockManager, blockIds1, onlyWindowSize); - - // blocks are reused, so reused blocks are still accessible by previous hashes - for (size_t i = 0; i < 4; i++) - { - auto range = blockManager.getBlocksByHash(blockHashes[i], onlyWindowSize); - EXPECT_NE(range.first, range.second); - } - // evicted block is not accessible - { - size_t i = 4; - auto range = blockManager.getBlocksByHash(blockHashes[i], onlyWindowSize); - EXPECT_EQ(range.first, range.second); - } - EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 5); -} - TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) { auto constexpr numLayers = 12; diff --git a/cpp/tests/unit_tests/executor/loraConfigTest.cpp b/cpp/tests/unit_tests/executor/loraConfigTest.cpp index 2859739f6e..6ce56cccbd 100644 --- a/cpp/tests/unit_tests/executor/loraConfigTest.cpp +++ b/cpp/tests/unit_tests/executor/loraConfigTest.cpp @@ -53,13 +53,12 @@ TEST(LoraConfigTest, invalidInputs) // This should work auto loraConfig = LoraConfig(1, weights, config); + // Having config only without weights is allowed + loraConfig = LoraConfig(1, std::nullopt, config); { - // Only one specified - testInvalid(1, std::nullopt, config, "must have both"); - - // Only one specified - testInvalid(1, weights, std::nullopt, "must have both"); + // Only weights specified without config - not allowed + testInvalid(1, weights, std::nullopt, "lora weights must also have lora config"); } { diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index 18f7e6f537..27fff8df7d 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -474,7 +474,7 @@ TEST(SerializeUtilsTest, VectorResponses) TEST(SerializeUtilsTest, KvCacheConfig) { texec::KvCacheConfig kvCacheConfig( - true, 10, std::vector(1, 100), 2, 0.1, 10000, false, 0.5, 50, 1024, false, false, true); + true, 10, std::vector(1, 100), 2, 0.1, 10000, false, 0.5, 50, 1024, false, false, true, 77); auto kvCacheConfig2 = serializeDeserialize(kvCacheConfig); EXPECT_EQ(kvCacheConfig.getEnableBlockReuse(), kvCacheConfig2.getEnableBlockReuse()); @@ -490,6 +490,7 @@ TEST(SerializeUtilsTest, KvCacheConfig) EXPECT_EQ(kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig2.getSecondaryOffloadMinPriority()); EXPECT_EQ(kvCacheConfig.getEventBufferMaxSize(), kvCacheConfig2.getEventBufferMaxSize()); EXPECT_EQ(kvCacheConfig.getUseUvm(), kvCacheConfig2.getUseUvm()); + EXPECT_EQ(kvCacheConfig.getAttentionDpEventsGatherPeriodMs(), kvCacheConfig2.getAttentionDpEventsGatherPeriodMs()); } TEST(SerializeUtilsTest, SchedulerConfig) @@ -846,6 +847,168 @@ TEST(SerializeUtilsTest, RequestStatsPerIteration) compareRequestStatsPerIteration(requestStatsPerIteration, requestStatsPerIteration2); } +void compareKvCacheEvents(texec::KVCacheEvent const& kvCacheEvent, texec::KVCacheEvent const& kvCacheEvent2) +{ + EXPECT_EQ(kvCacheEvent.eventId, kvCacheEvent2.eventId); + EXPECT_EQ(kvCacheEvent.windowSize, kvCacheEvent2.windowSize); + EXPECT_EQ(kvCacheEvent.attentionDpRank, kvCacheEvent2.attentionDpRank); + + if (std::holds_alternative<texec::KVCacheCreatedData>(kvCacheEvent.data)) + { + EXPECT_TRUE(std::holds_alternative<texec::KVCacheCreatedData>(kvCacheEvent2.data)); + auto data = std::get<texec::KVCacheCreatedData>(kvCacheEvent.data); + auto data2 = std::get<texec::KVCacheCreatedData>(kvCacheEvent2.data); + EXPECT_EQ(data.numBlocksPerCacheLevel, data2.numBlocksPerCacheLevel); + } + else if (std::holds_alternative<texec::KVCacheRemovedData>(kvCacheEvent.data)) + { + EXPECT_TRUE(std::holds_alternative<texec::KVCacheRemovedData>(kvCacheEvent2.data)); + auto data = std::get<texec::KVCacheRemovedData>(kvCacheEvent.data); + auto data2 = std::get<texec::KVCacheRemovedData>(kvCacheEvent2.data); + EXPECT_EQ(data.blockHashes, data2.blockHashes); + } + else if (std::holds_alternative<texec::KVCacheStoredData>(kvCacheEvent.data)) + { + EXPECT_TRUE(std::holds_alternative<texec::KVCacheStoredData>(kvCacheEvent2.data)); + auto data = std::get<texec::KVCacheStoredData>(kvCacheEvent.data); + auto data2 = std::get<texec::KVCacheStoredData>(kvCacheEvent2.data); + EXPECT_EQ(data.parentHash, data2.parentHash); + EXPECT_EQ(data.blocks.size(), data2.blocks.size()); + for (size_t i = 0; i < data.blocks.size(); ++i) + { + auto blockData = data.blocks[i]; + auto blockData2 = data2.blocks[i]; + EXPECT_EQ(blockData.blockHash, blockData2.blockHash); + EXPECT_EQ(blockData.loraId, blockData2.loraId); + EXPECT_EQ(blockData.cacheLevel, blockData2.cacheLevel); + EXPECT_EQ(blockData.priority, blockData2.priority); + EXPECT_EQ(blockData.tokens.size(), blockData2.tokens.size()); + for (size_t j = 0; j < blockData.tokens.size(); ++j) + { + EXPECT_EQ(blockData.tokens[j].tokenId, blockData2.tokens[j].tokenId); + EXPECT_EQ(blockData.tokens[j].tokenExtraId, blockData2.tokens[j].tokenExtraId); + } + } + } + else if (std::holds_alternative<texec::KVCacheUpdatedData>(kvCacheEvent.data)) + { + EXPECT_TRUE(std::holds_alternative<texec::KVCacheUpdatedData>(kvCacheEvent2.data)); + auto data = std::get<texec::KVCacheUpdatedData>(kvCacheEvent.data); + auto data2 = std::get<texec::KVCacheUpdatedData>(kvCacheEvent2.data); + EXPECT_EQ(data.blockHash, data2.blockHash); + if (data.cacheLevel) + { + EXPECT_TRUE(data2.cacheLevel); + EXPECT_EQ(data.cacheLevel.value().oldValue, data2.cacheLevel.value().oldValue); + EXPECT_EQ(data.cacheLevel.value().newValue, data2.cacheLevel.value().newValue); + } + if (data.priority) + { + EXPECT_TRUE(data2.priority); + EXPECT_EQ(data.priority.value().oldValue, data2.priority.value().oldValue); + EXPECT_EQ(data.priority.value().newValue, data2.priority.value().newValue); + } + } + else + { + FAIL() << "Unknown KVCacheEvent data type"; + } +} + +TEST(SerializeUtilsTest, KvCacheEventsDeque) +{ + // Created event + texec::KVCacheCreatedData kvCacheCreatedData{{1, 2}}; + texec::KVCacheEvent kvCacheCreatedEvent(1, kvCacheCreatedData, 32); + + // Removed event + texec::KVCacheEvent kvCacheRemovedEvent(1, texec::KVCacheRemovedData{{3, 4}}, 32); + + // Stored event + auto storedBlockData1 = texec::KVCacheStoredBlockData(77, {{1, 2}, {3, 4}, {5, 6}}, 88, 0, 99); + auto storedBlockData2 = texec::KVCacheStoredBlockData(99, {{11, 12}, {3, 4}, {15, 6}}, 77, 1, 101); + texec::KVCacheStoredData kvCacheStoredData{177, {storedBlockData1, storedBlockData2}}; + texec::KVCacheEvent kvCacheStoredEvent(1, kvCacheStoredData, 32); + + // Updated event + texec::KVCacheEventDiff<texec::SizeType32> diff{0, 1}; + texec::KVCacheEventDiff<texec::SizeType32> diff2{90, 99}; + texec::KVCacheUpdatedData kvCacheUpdatedData(999, diff, diff2); + texec::KVCacheEvent kvCacheEvent(1, kvCacheUpdatedData, 32); + + std::deque<texec::KVCacheEvent> kvCacheEvents{ + kvCacheCreatedEvent, kvCacheRemovedEvent, kvCacheStoredEvent, kvCacheEvent}; + + auto serializedEvents = texec::Serialization::serialize(kvCacheEvents); + auto kvCacheEvents2 = texec::Serialization::deserializeKVCacheEvents(serializedEvents); + + EXPECT_EQ(kvCacheEvents.size(), kvCacheEvents2.size()); + for (size_t i = 0; i < kvCacheEvents.size(); ++i) + { + compareKvCacheEvents(kvCacheEvents[i], kvCacheEvents2[i]); + } +} + +// Test for KVCacheEvent with KVCacheCreatedData +TEST(SerializeUtilsTest, KVCacheCreatedEvent) +{ + texec::KVCacheCreatedData kvCacheCreatedData{{1, 2}}; + texec::KVCacheEvent kvCacheEvent(1, kvCacheCreatedData, 32); + auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent); + compareKvCacheEvents(kvCacheEvent, kvCacheEvent2); +} + +// Test for KVCacheEvent with KVCacheRemovedData +TEST(SerializeUtilsTest, KVCacheRemovedEvents) +{ + texec::KVCacheEvent kvCacheEvent(1, texec::KVCacheRemovedData{{3, 4}}, 32); + auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent); + compareKvCacheEvents(kvCacheEvent, kvCacheEvent2); +} + +// Test for KVCacheEvent with KVCacheStoredData +TEST(SerializeUtilsTest, KVCacheStoredEvent) +{ + auto storedBlockData1 = texec::KVCacheStoredBlockData(77, {{1, 2}, {3, 4}, {5, 6}}, 88, 0, 99); + auto storedBlockData2 = texec::KVCacheStoredBlockData(99, {{11, 12}, {3, 4}, {15, 6}}, 77, 1, 101); + + texec::KVCacheStoredData kvCacheStoredData{177, {storedBlockData1, storedBlockData2}}; + texec::KVCacheEvent kvCacheEvent(1, kvCacheStoredData, 32); + auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent); + compareKvCacheEvents(kvCacheEvent, kvCacheEvent2); +} + +// Test for KVCacheEvent with KVCacheUpdatedData +TEST(SerializeUtilsTest, KVCacheUpdatedEvent) +{ + texec::KVCacheEventDiff<texec::SizeType32> diff{0, 1}; + texec::KVCacheEventDiff<texec::SizeType32> diff2{90, 99}; + texec::KVCacheUpdatedData kvCacheUpdatedData(999, diff, diff2); + texec::KVCacheEvent kvCacheEvent(1, kvCacheUpdatedData, 32); + auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent); + compareKvCacheEvents(kvCacheEvent, kvCacheEvent2); +} + +TEST(SerializeUtilsTest, UniqueToken) +{ + tensorrt_llm::runtime::UniqueToken token{1, 2}; + auto token2 = serializeDeserialize(token); + EXPECT_EQ(token.tokenId, token2.tokenId); + EXPECT_EQ(token.tokenExtraId, token2.tokenExtraId); +} + +TEST(SerializeUtilsTest, UniqueTokenVector) +{ + std::vector<tensorrt_llm::runtime::UniqueToken> tokens{{1, 2}, {3, 4}, {5, 6}}; + auto tokens2 = serializeDeserialize(tokens); + EXPECT_EQ(tokens.size(), tokens2.size()); + for (size_t i = 0; i < tokens.size(); ++i) + { + EXPECT_EQ(tokens[i].tokenId, tokens2[i].tokenId); + EXPECT_EQ(tokens[i].tokenExtraId, tokens2[i].tokenExtraId); + } +} + TEST(SerializeUtilsTest, MethodReturnType) { struct S diff --git a/cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu b/cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu index 3447efb62c..80c6aee4fe 100644 --- a/cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu +++ b/cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu @@ -419,9 +419,9 @@ public: CudaBuffer ref_scale(scale_out_size); // Here, we also only compare the accuracy of quantization. Since there are no differences in // computation order, atol is set to 0. - invokeFP4Quantization(token_num, hidden_dim, m_norm_out.device_data<DType>(), + invokeFP4Quantization(1, token_num, hidden_dim, m_norm_out.device_data<DType>(), m_scale_factor.device_data<float>(), ref_output.device_data<int64_t>(), - ref_scale.device_data<int32_t>(), false, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED, 128, 0); + ref_scale.device_data<int32_t>(), false, tensorrt_llm::QuantizationSFLayout::SWIZZLED, 128, 0); TLLM_CHECK(compare<int8_t>( m_rank, m_quant_out.host_data(), ref_output.host_data(), message_size / 2, "fp4 quant out", 0)); TLLM_CHECK(compare<int8_t>( @@ -460,9 +460,9 @@ public: void run_fp4_quant(int token_num, int hidden_dim) { - invokeFP4Quantization(token_num, hidden_dim, m_norm_out.device_data<DType>(), + invokeFP4Quantization(1, token_num, hidden_dim, m_norm_out.device_data<DType>(), m_scale_factor.device_data<float>(), m_quant_out.device_data<int64_t>(), m_scale_out.device_data<int32_t>(), - false, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED, 128, m_stream->get()); + false, tensorrt_llm::QuantizationSFLayout::SWIZZLED, 128, m_stream->get()); } void run_kernel(int token_num, int hidden_dim) diff --git a/cpp/tests/unit_tests/kernels/allReduce/moeAllReduceFusionTest.cu b/cpp/tests/unit_tests/kernels/allReduce/moeAllReduceFusionTest.cu index ef7712eff0..576d881f4f 100644 --- a/cpp/tests/unit_tests/kernels/allReduce/moeAllReduceFusionTest.cu +++ b/cpp/tests/unit_tests/kernels/allReduce/moeAllReduceFusionTest.cu @@ -544,7 +544,7 @@ public: // * Quant invokeFP4Quantization(token_num, hidden_dim, m_norm_out.device_data<DType>(), m_scale_factor.device_data<float>(), ref_output.device_data<int64_t>(), ref_scale.device_data<int32_t>(), - false, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED, 128, 0); + false, tensorrt_llm::QuantizationSFLayout::SWIZZLED, 128, 0); compare<int8_t>(m_rank, m_quant_out.host_data(), ref_output.host_data(), message_size / 2, 1e-3, "quant out"); compare<int8_t>(m_rank, m_scale_out.host_data(), ref_scale.host_data(), message_size / 16, 1e-3, "scale out"); } @@ -584,7 +584,7 @@ public: m_scale_factor.device_data<float>(), // input sf m_quant_out.device_data<int64_t>(), // output m_scale_out.device_data<int32_t>(), // output sf - false, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED, 128, m_stream->get()); + false, tensorrt_llm::QuantizationSFLayout::SWIZZLED, 128, m_stream->get()); } void run_kernel(int token_num, int hidden_dim) diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index c9e4a065eb..6f2ce0f93e 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h" @@ -9,15 +25,16 @@ #ifdef USING_OSS_CUTLASS_MOE_GEMM #include "tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h" +#include <tensorrt_llm/kernels/quantization.h> #else #include "moe_kernels.h" +#include "quantization.h" #endif #include "tensorrt_llm/kernels/cutlass_kernels/include/cutlass_kernel_selector.h" #include "tensorrt_llm/runtime/bufferManager.h" #include <tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h> -#include <tensorrt_llm/kernels/quantization.h> using namespace tensorrt_llm::kernels; using namespace tensorrt_llm::common; @@ -27,6 +44,7 @@ using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput; using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner; using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; +using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams; using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation; constexpr static float FP8_MAX = 448.f; @@ -128,6 +146,8 @@ protected: using OutputType = typename TypeTuple_::OutputType; using ActivationScale = typename TypeTuple_::ActivationScale; using WeightScale = typename TypeTuple_::WeightScale; + + using BackBoneType = OutputType; constexpr static bool INT4 = std::is_same_v<WeightType, cutlass::uint4b_t>; constexpr static bool ACT_FP8 = std::is_same_v<GemmDataType, SafeFP8>; constexpr static bool WEIGHT_FP8 = std::is_same_v<WeightType, SafeFP8>; @@ -245,11 +265,11 @@ protected: initWeightsKernel<WeightRawType><<<grid, block, 0, mStream->get()>>>(buffer, w, h, base, scalar); } - void initBias(DataType* buffer, int64_t w) + void initBias(BackBoneType* buffer, int64_t w) { dim3 block(256, 1, 1); dim3 grid(divUp(w, block.x), mNumExperts); - initBiasToExpertIdKernel<DataType><<<grid, block, 0, mStream->get()>>>(buffer, w); + initBiasToExpertIdKernel<BackBoneType><<<grid, block, 0, mStream->get()>>>(buffer, w); } void initWeightsGated(WeightRawType* buffer, int64_t w, int64_t h, float base_1, float base_2, float scalar) @@ -263,7 +283,7 @@ protected: initWeightsGatedKernel<WeightRawType><<<grid, block, 0, mStream->get()>>>(buffer, w, h, base_1, base_2, scalar); } - void initBiasGated(DataType* buffer, int64_t w) + void initBiasGated(BackBoneType* buffer, int64_t w) { if (!mIsGated) return initBias(buffer, w); @@ -271,10 +291,10 @@ protected: w /= 2; dim3 block(256, 1, 1); dim3 grid(divUp(w, block.x), mNumExperts); - initBiasToExpertIdGatedKernel<DataType><<<grid, block, 0, mStream->get()>>>(buffer, w); + initBiasToExpertIdGatedKernel<BackBoneType><<<grid, block, 0, mStream->get()>>>(buffer, w); } - CutlassMoeFCRunner<GemmDataType, WeightType, OutputType, InputType> mMoERunner{}; + CutlassMoeFCRunner<GemmDataType, WeightType, OutputType, InputType, BackBoneType> mMoERunner{}; char* mWorkspace{}; int* mSelectedExpert; float* mTokenFinalScales{}; @@ -282,6 +302,14 @@ protected: WeightRawType* mRawExpertWeight2{}; WeightStorage* mExpertWeight1{}; WeightStorage* mExpertWeight2{}; + + float mSwigluAlphaValue{0.5f}; + float mSwigluBetaValue{MX_QUANT_ACT ? 0.0f : 1.f}; + float mSwigluLimitValue{MX_QUANT_ACT ? FP8_MAX / 4 : NVFP4 ? 2.f : 0.5f}; + float* mSwigluAlpha{}; + float* mSwigluBeta{}; + float* mSwigluLimit{}; + DataType* mExpertIntScale1{}; DataType* mExpertIntScale2{}; @@ -314,8 +342,8 @@ protected: ElementSF* mFP4ScalingFactorsW1 = nullptr; ElementSF* mFP4ScalingFactorsW2 = nullptr; - DataType* mExpertBias1{}; - DataType* mExpertBias2{}; + BackBoneType* mExpertBias1{}; + BackBoneType* mExpertBias2{}; void* mTpExpertScratch{}; // Copy the experts here when slicing up inputs size_t mTpExpertScratchSize{}; @@ -343,7 +371,7 @@ protected: float mSparseMixerEpsilon = 0.2f; // Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths - bool mUseDeterminsiticHopperReduce = true; + bool mUseDeterministicHopperReduce = true; // Disable this for long running tests to speed up runtime bool mIsLongTest = false; @@ -428,7 +456,7 @@ protected: { managed_buffers.clear(); - mMoERunner.use_deterministic_hopper_reduce_ = k > 2 && mUseDeterminsiticHopperReduce; + mMoERunner.use_fused_finalize_ = k < 3 || !mUseDeterministicHopperReduce; mHiddenSize = hidden_size; mInterSize = hidden_size * mInterSizeFraction; @@ -467,12 +495,14 @@ protected: if (mUseBias) { // Allow space for the slice of bias1 in the scratch - mTpExpertScratchSize += sizeof(DataType) * experts_per_node * gated_inter / parallelism_config.tp_size; - mExpertBias1 = allocBuffer<DataType>(mNumExperts * gated_inter); - mExpertBias2 = allocBuffer<DataType>(mNumExperts * mHiddenSize); + mTpExpertScratchSize += sizeof(BackBoneType) * experts_per_node * gated_inter / parallelism_config.tp_size; + mExpertBias1 = allocBuffer<BackBoneType>(mNumExperts * gated_inter); + mExpertBias2 = allocBuffer<BackBoneType>(mNumExperts * mHiddenSize); - check_cuda_error(cudaMemsetAsync(mExpertBias1, 0x0, mNumExperts * gated_inter * sizeof(DataType), stream)); - check_cuda_error(cudaMemsetAsync(mExpertBias2, 0x0, mNumExperts * mHiddenSize * sizeof(DataType), stream)); + check_cuda_error( + cudaMemsetAsync(mExpertBias1, 0x0, mNumExperts * gated_inter * sizeof(BackBoneType), stream)); + check_cuda_error( + cudaMemsetAsync(mExpertBias2, 0x0, mNumExperts * mHiddenSize * sizeof(BackBoneType), stream)); } if constexpr (INT_QUANT) @@ -539,6 +569,22 @@ protected: mSourceToExpandedMap = allocBuffer<int>(mTotalTokens * mK); + if (mActType == ActivationType::SwigluBias) + { + mSwigluAlpha = allocBuffer<float>(mNumExperts); + mSwigluBeta = allocBuffer<float>(mNumExperts); + mSwigluLimit = allocBuffer<float>(mNumExperts); + std::vector<float> h_swiglu_alpha(mNumExperts, mSwigluAlphaValue); + std::vector<float> h_swiglu_beta(mNumExperts, mSwigluBetaValue); + std::vector<float> h_swiglu_limit(mNumExperts, mSwigluLimitValue); + check_cuda_error(cudaMemcpyAsync( + mSwigluAlpha, h_swiglu_alpha.data(), mNumExperts * sizeof(float), cudaMemcpyHostToDevice, stream)); + check_cuda_error(cudaMemcpyAsync( + mSwigluBeta, h_swiglu_beta.data(), mNumExperts * sizeof(float), cudaMemcpyHostToDevice, stream)); + check_cuda_error(cudaMemcpyAsync( + mSwigluLimit, h_swiglu_limit.data(), mNumExperts * sizeof(float), cudaMemcpyHostToDevice, stream)); + } + check_cuda_error(cudaMemcpyAsync(mSelectedExpert, h_token_selected_experts.data(), mTotalTokens * mK * sizeof(int), cudaMemcpyHostToDevice, stream)); check_cuda_error(cudaMemcpyAsync(mTokenFinalScales, h_token_final_scales.data(), @@ -606,9 +652,15 @@ protected: int64_t padded_in_dim = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(in_shape, MinKDimAlignmentFP4); check_cuda_error(cudaMemsetAsync(scaling_factors, 0x00, num_experts * padded_out_dim * padded_in_dim / FP4VecSize * sizeof(ElementSF), mStream->get())); +#ifdef USING_OSS_CUTLASS_MOE_GEMM + invokeFP4Quantization<WeightRawType, FP4VecSize>(num_experts, out_shape, in_shape, raw_weights, global_scales, + reinterpret_cast<int64_t*>(quant_weights), reinterpret_cast<int32_t*>(scaling_factors), MX_QUANT_WEIGHT, + tensorrt_llm::QuantizationSFLayout::SWIZZLED, mMultiProcessorCount, mStream->get()); +#else invokeBatchedFP4Quantization<WeightRawType, FP4VecSize>(num_experts, out_shape, in_shape, raw_weights, global_scales, reinterpret_cast<int64_t*>(quant_weights), reinterpret_cast<int32_t*>(scaling_factors), MX_QUANT_WEIGHT, mMultiProcessorCount, mStream->get()); +#endif // auto sf_data = getDataFromDevice<ElementSF>(scaling_factors, num_experts * padded_out_dim * padded_in_dim / // FP4VecSize); auto unquant_data = getDataFromDevice<WeightRawType>(raw_weights, num_experts * out_shape * @@ -873,7 +925,8 @@ protected: // Generates numbers in increments of 1/max_order_of_magnitude in the range [0, 1) constexpr int max_order_of_magnitude = 256; std::vector<int> base(hidden_states.size()); - std::iota(base.begin(), base.end(), 0); + // Start from the near largest value so we always have some large values even for small hidden sizes + std::iota(base.begin(), base.end(), max_order_of_magnitude - 4); std::mt19937 gen(0xD5); std::shuffle(base.begin(), base.end(), gen); // Lambda subtracts a small value so we have some < 0 to test the activation for negatives @@ -998,7 +1051,7 @@ protected: auto* weight_1 = reinterpret_cast<SliceWeightType*>(mTpExpertScratch); auto* weight_2 = weight_1 + experts_per_node * gated_matrix_size / SLICED_WEIGHT_ELEM_PER_BYTE; auto* bias_1 - = reinterpret_cast<DataType*>(weight_2 + experts_per_node * matrix_size / SLICED_WEIGHT_ELEM_PER_BYTE); + = reinterpret_cast<BackBoneType*>(weight_2 + experts_per_node * matrix_size / SLICED_WEIGHT_ELEM_PER_BYTE); // 2D memcpy just the slices we care about // TODO Re-quantize here with matrices divided @@ -1014,7 +1067,7 @@ protected: if (mUseBias) { - size_t const row_size_bias = row_size_inter * sizeof(DataType); + size_t const row_size_bias = row_size_inter * sizeof(BackBoneType); check_cuda_error(cudaMemcpy2DAsync(bias_1, row_size_bias, (uint8_t*) bias1_ptr + row_size_bias * tp_rank, row_size_bias * tp_size, row_size_bias, experts_per_node * mGatedMultiplier, cudaMemcpyDeviceToDevice, mStream->get())); @@ -1173,15 +1226,17 @@ protected: MoeMinLatencyParams min_latency_params; mMoERunner.setTactic(tactic1, tactic2); #ifdef USING_OSS_CUTLASS_MOE_GEMM - mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, mActType, - weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, - mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, enable_alltoall, - mUseLora, lora_params, useFp8BlockScales, minLatencyMode, min_latency_params, stream); -#else - mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, mActType, - weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, - mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, mUseLora, lora_params, + mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, + ActivationParams(mActType, mSwigluAlpha, mSwigluBeta, mSwigluLimit), weight2_ptr, bias2_ptr, quant_params, + mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace, + mFinalOutput, mSourceToExpandedMap, parallelism_config, enable_alltoall, mUseLora, lora_params, useFp8BlockScales, minLatencyMode, min_latency_params, stream); +#else + mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, + ActivationParams(mActType, mSwigluAlpha, mSwigluBeta, mSwigluLimit), weight2_ptr, bias2_ptr, quant_params, + mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace, + mFinalOutput, mSourceToExpandedMap, parallelism_config, mUseLora, lora_params, useFp8BlockScales, + minLatencyMode, min_latency_params, stream); #endif check_cuda_error(cudaStreamSynchronize(stream)); @@ -1255,20 +1310,27 @@ protected: } template <class T> - T actfn(T in) + T actfn(T gate, T linear = T(0.0f), ActivationType act_type = ActivationType::InvalidType) { - if (mActType == ActivationType::Identity) - return in; - if (mActType == ActivationType::Relu) - return std::max(in, T(0.0f)); - if (mActType == ActivationType::Gelu || mActType == ActivationType::Geglu) - return (std::erf(float(in) * float(sqrt(0.5))) + 1) * 0.5f * float(in); - if (mActType == ActivationType::Silu || mActType == ActivationType::Swiglu) + if (act_type == ActivationType::InvalidType) + act_type = mActType; + + switch (act_type) { - return (float(in) / (1.f + std::exp(-(in)))); + case ActivationType::Identity: return gate; + case ActivationType::Relu: return std::max(gate, T(0.0f)); + case ActivationType::Gelu: return ((std::erf(float(gate) * float(sqrt(0.5))) + 1) * 0.5f * float(gate)); + case ActivationType::Silu: return (float(gate) / (1.f + std::exp(-(gate)))); + case ActivationType::Geglu: return actfn(gate, 0.0f, ActivationType::Gelu) * linear; + case ActivationType::Swiglu: return actfn(gate, 0.0f, ActivationType::Silu) * linear; + case ActivationType::SwigluBias: + linear = std::min(std::max(linear, -mSwigluLimitValue), mSwigluLimitValue); + gate = std::min(gate, mSwigluLimitValue); + // silu(gate * alpha) / alpha = gate * sigmoid(gate * alpha) + return actfn(gate * mSwigluAlphaValue, 0.0f, ActivationType::Silu) / mSwigluAlphaValue + * (linear + mSwigluBetaValue); + default: assert(false); return gate; } - assert(false); - return in; } float quantAct(float in, float block_max) @@ -1292,15 +1354,14 @@ protected: if (mIsGated) { float scalar = applyExpertShift(mExpertWDiag1, expert_id); - float fc1 = input * scalar + w1_bias; - + float linear = input * scalar + w1_bias; float gated_scalar = applyExpertShift(mExpertWDiagGated, expert_id); float gated_bias = mUseBias ? w1_bias + 1.f : 0.f; float gate = input * gated_scalar + gated_bias; - activated = fc1 * actfn(gate); + activated = actfn(gate, linear); - block_max = (block_max * scalar + w1_bias) * actfn(block_max * gated_scalar + gated_bias); + block_max = actfn(block_max * gated_scalar + gated_bias, block_max * scalar + w1_bias); } else { @@ -1390,6 +1451,12 @@ protected: void compareFinal(std::vector<int> const& expected_experts, std::vector<float> const& token_final_scales, std::vector<OutputType> const& input_data, std::vector<OutputType> final_results = {}) { + if (mActType == ActivationType::SwigluBias) + { + ASSERT_GT(mMaxInput * std::max(mExpertWDiag1, mExpertWDiagGated), mSwigluLimitValue) + << "SwigluBias limit values don't change the result"; + } + ASSERT_EQ(expected_experts.size(), token_final_scales.size()); ASSERT_EQ(expected_experts.size() / mK, input_data.size() / mHiddenSize); if (final_results.empty()) @@ -1533,7 +1600,7 @@ template <class TypeParam_> void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest( int k, int64_t hidden_size, int64_t num_experts, int64_t num_tokens) { - if constexpr (ANY_FPX) + if (NVFP4 || (MXFP8_MXFP4 && isGatedActivation(mActType))) { // TODO Remove this when bias + FPX is supported mUseBias = false; @@ -1563,7 +1630,7 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest( runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k); bool should_be_deterministic - = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1672,9 +1739,17 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteSwiglu) this->BasicPermuteTest(3); } +TYPED_TEST(MixtureOfExpertsTest, PermuteSwigluBias) +{ + this->mActType = ActivationType::SwigluBias; + this->BasicPermuteTest(); + this->BasicPermuteTest(2); + this->BasicPermuteTest(3); +} + TYPED_TEST(MixtureOfExpertsTest, PermuteNonDeterministic) { - this->mUseDeterminsiticHopperReduce = false; + this->mUseDeterministicHopperReduce = false; // Just test case 3, cases 1&2 always use the fused paths this->BasicPermuteTest(3); } @@ -1776,7 +1851,7 @@ template <class TypeParam_> void MixtureOfExpertsTest<TypeParam_>::ParallelismTest( int k, int tp_size, int ep_size, int64_t hidden_size, int64_t num_experts, int64_t num_tokens, bool enable_alltoall) { - if (ANY_FPX) + if (NVFP4 || (MXFP8_MXFP4 && isGatedActivation(mActType))) { // TODO Remove this when bias + FPX is supported mUseBias = false; @@ -1822,7 +1897,7 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest( runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k, MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); bool should_be_deterministic - = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1838,7 +1913,7 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest( { runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); bool should_be_deterministic - = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1940,6 +2015,13 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest( this->ParallelismType##Test(); \ this->ParallelismType##Test(2); \ this->ParallelismType##Test(3); \ + } \ + TYPED_TEST(MixtureOfExpertsTest, ParallelismType##SwigluBias) \ + { \ + this->mActType = ActivationType::SwigluBias; \ + this->ParallelismType##Test(); \ + this->ParallelismType##Test(2); \ + this->ParallelismType##Test(3); \ } \ \ TYPED_TEST(MixtureOfExpertsTest, ParallelismType##Mixtral8x7b) \ @@ -2018,7 +2100,7 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) return tactic.str(); }; - auto activation_pool = std::vector{ActivationType::Relu, ActivationType::Swiglu, ActivationType::Geglu}; + auto activation_pool = std::vector{ActivationType::Relu, ActivationType::Swiglu, ActivationType::SwigluBias}; if (this->NVFP4) activation_pool = {ActivationType::Relu}; auto configs = this->getFilteredConfigs(getSMVersion()); @@ -2036,12 +2118,12 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) } ASSERT_NO_THROW({ this->mActType = activation_type; - for (int k = 1; k <= 3; k++) + for (auto k : {2, 3}) { this->mOverrideSelectedConfig1 = conf1; this->mOverrideSelectedConfig2 = conf2; - this->BasicPermuteTest(k); + this->BasicPermuteTest(k, this->MINIMUM_ALIGNMENT); if (::testing::Test::HasFailure()) // Throw on test failure so we get the print message throw std::runtime_error("Test k=" + std::to_string(k) + " Failed"); } @@ -2075,7 +2157,7 @@ TYPED_TEST(LargeMixtureOfExpertsTest, PermuteVeryLargeExperts) TYPED_TEST(LargeMixtureOfExpertsTest, PermuteVeryLongSequence) { this->mIsLongTest = true; - this->mUseBias = !this->ANY_FPX; + this->mUseBias = !this->NVFP4; using DataType = typename MixtureOfExpertsTest<TypeParam>::DataType; // Sequence * hidden size > INT32_MAX diff --git a/cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu b/cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu index ccb2a77fcc..18e1838533 100644 --- a/cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu +++ b/cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu @@ -132,7 +132,6 @@ void selfAttentionRef(T* output, T* const Q, T* const KV, int batch_size, int nu int global_q_offset = cu_seq_q_len[b] * num_heads * head_size; int global_kv_offset = cu_seq_kv_len[b] * 2 * num_heads * head_size; int global_softmax_offset = cu_seq_q_len[b] * num_heads * 2; - float bmm1_scale = 1.F / std::sqrt(static_cast<float>(head_size)); if (curr_q_len == 0 || curr_kv_len == 0) { continue; // skip empty sequences @@ -187,9 +186,8 @@ void selfAttentionRef(T* output, T* const Q, T* const KV, int batch_size, int nu float sum = 0; for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) { - // P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv] * bmm1_scale); - // hack for real mla kernel - P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv] * 0.072168784); + // P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv]); + P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv]); sum += P[s_q * curr_kv_len + s_kv]; } for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) diff --git a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp index 48c90eaff1..67cdd77ddc 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp @@ -178,6 +178,28 @@ private: TYPED_TEST_SUITE(RoutingRenormalizeKernelTest, FloatAndBf16Types); +TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization) +{ + RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, + /*numExperts=*/128, /*topK=*/8, + /*expertParallelization=*/1, /*expertParallelizationId=*/0, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + this->runTest(param); +}; + +TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertParallelization) +{ + RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/14, + /*numExperts=*/128, /*topK=*/8, + /*expertParallelization=*/2, /*expertParallelizationId=*/1, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + this->runTest(param); +}; + TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization) { RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10, diff --git a/cpp/tests/unit_tests/kernels/smoothQuant/smoothQuantKernelTest.cpp b/cpp/tests/unit_tests/kernels/smoothQuant/smoothQuantKernelTest.cpp index d3779f5240..e6936705e7 100644 --- a/cpp/tests/unit_tests/kernels/smoothQuant/smoothQuantKernelTest.cpp +++ b/cpp/tests/unit_tests/kernels/smoothQuant/smoothQuantKernelTest.cpp @@ -286,13 +286,13 @@ TEST(Kernel, WeightOnly) std::vector<int> ks{2048, 4096}; std::vector<tensorrt_llm::common::QuantMode> quant_modes(4); quant_modes[0] = tensorrt_llm::common::QuantMode::fromDescription( - false, false, false, false, false, false, false, false, false, false, false, false, false, false); + false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false); quant_modes[1] = tensorrt_llm::common::QuantMode::fromDescription( - false, false, true, false, false, false, false, false, false, false, false, false, false, false); + false, false, true, false, false, false, false, false, false, false, false, false, false, false, false, false); quant_modes[2] = tensorrt_llm::common::QuantMode::fromDescription( - false, false, false, true, false, false, false, false, false, false, false, false, false, false); + false, false, false, true, false, false, false, false, false, false, false, false, false, false, false, false); quant_modes[3] = tensorrt_llm::common::QuantMode::fromDescription( - false, false, true, true, false, false, false, false, false, false, false, false, false, false); + false, false, true, true, false, false, false, false, false, false, false, false, false, false, false, false); for (auto m : ms) { for (auto n : ns) diff --git a/cpp/tests/unit_tests/runtime/decodingLayerWorkspaceTest.cpp b/cpp/tests/unit_tests/runtime/decodingLayerWorkspaceTest.cpp index 066ad5a8ca..bb6ce6410a 100644 --- a/cpp/tests/unit_tests/runtime/decodingLayerWorkspaceTest.cpp +++ b/cpp/tests/unit_tests/runtime/decodingLayerWorkspaceTest.cpp @@ -16,6 +16,7 @@ #include "tensorrt_llm/runtime/decodingLayerWorkspace.h" #include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/workspace.h" #include <gtest/gtest.h> #include <random> @@ -171,7 +172,7 @@ TEST_P(MirrorInWorkspaceTest, TestMirrorInWorkspaceFunctionality) requiredWorkspaceSize) << "The calculated workspace size cannot possibly be enough to contain all the tensors."; - constexpr std::size_t addressAlignment = 128; + constexpr std::size_t addressAlignment = tensorrt_llm::common::kCudaMemAlign; constexpr std::size_t numTensors = 3; constexpr std::size_t maxAlignmentOverhead = numTensors * addressAlignment; ASSERT_GE(hostTensor1->getSizeInBytes() + hostTensor2->getSizeInBytes() + hostTensor3->getSizeInBytes() diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi index aba824f7ba..dcd4ca073b 100644 --- a/docker/Dockerfile.multi +++ b/docker/Dockerfile.multi @@ -71,8 +71,9 @@ RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh ENV PYTORCH_CUDA_ALLOC_CONF="garbage_collection_threshold:0.99999" # Install OpenCV with FFMPEG support -RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/ -RUN pip3 install opencv-python-headless --force-reinstall --no-deps --no-cache-dir +RUN pip3 uninstall -y opencv && \ + rm -rf /usr/local/lib/python3*/dist-packages/cv2/ && \ + pip3 install opencv-python-headless --force-reinstall --no-deps --no-cache-dir # COPY bringup_fix.sh bringup_fix.sh # RUN bash ./bringup_fix.sh && rm bringup_fix.sh diff --git a/docker/Makefile b/docker/Makefile index b70763f09c..81adc8e532 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -1,6 +1,8 @@ # Default base image for the docker build as defined in Dockerfile.multi BASE_IMAGE ?= $(shell grep '^ARG BASE_IMAGE=' Dockerfile.multi | grep -o '=.*' | tr -d '="') BASE_TAG ?= $(shell grep '^ARG BASE_TAG=' Dockerfile.multi | grep -o '=.*' | tr -d '="') +TRITON_IMAGE ?= $(shell grep '^ARG TRITON_IMAGE=' Dockerfile.multi | grep -o '=.*' | tr -d '="') +TRITON_BASE_TAG ?= $(shell grep '^ARG TRITON_BASE_TAG=' Dockerfile.multi | grep -o '=.*' | tr -d '="') # Name of the new image IMAGE_NAME ?= tensorrt_llm IMAGE_TAG ?= latest @@ -80,6 +82,8 @@ endef --progress $(DOCKER_PROGRESS) \ $(if $(BASE_IMAGE), --build-arg BASE_IMAGE=$(BASE_IMAGE)) \ $(if $(BASE_TAG), --build-arg BASE_TAG=$(BASE_TAG)) \ + $(if $(TRITON_IMAGE), --build-arg TRITON_IMAGE=$(TRITON_IMAGE)) \ + $(if $(TRITON_BASE_TAG), --build-arg TRITON_BASE_TAG=$(TRITON_BASE_TAG)) \ $(if $(BUILD_WHEEL_ARGS), --build-arg BUILD_WHEEL_ARGS="$(BUILD_WHEEL_ARGS)") \ $(if $(BUILD_WHEEL_SCRIPT), --build-arg BUILD_WHEEL_SCRIPT="$(BUILD_WHEEL_SCRIPT)") \ $(if $(TORCH_INSTALL_TYPE), --build-arg TORCH_INSTALL_TYPE="$(TORCH_INSTALL_TYPE)") \ diff --git a/docs/source/advanced/disaggregated-service.md b/docs/source/advanced/disaggregated-service.md index e5c4a19ba4..a9955b940a 100644 --- a/docs/source/advanced/disaggregated-service.md +++ b/docs/source/advanced/disaggregated-service.md @@ -1,10 +1,10 @@ (disaggregated-service)= -# Disaggregated-Service (Experimental) +# Disaggregated-Service (Prototype) ```{note} Note: -This feature is currently experimental, and the related API is subjected to change in future versions. +This feature is currently in prototype, and the related API is subjected to change in future versions. ``` Currently TRT-LLM supports `disaggregated-service`, where the context and generation phases of a request can run on different executors. TRT-LLM's disaggregated service relies on the executor API, please make sure to read the [executor page](executor.md) before reading the document. @@ -66,17 +66,6 @@ A. Yes, it's recommended that different executor use different GPUs . We support ### Debugging FAQs -*Q. How to handle error `Disaggregated serving is not enabled, please check the configuration?`* - -A. please set `backendType` of `CacheTransceiverConfig`. -```cpp -ExecutorConfig executorConfig{...}; - -executorConfig.setCacheTransceiverConfig(texec::CacheTransceiverConfig(BackendType::DEFAULT)); -``` - -When the environment variable `TRTLLM_USE_MPI_KVCACHE=1` is set, TRT-LLM will transfer the KV cache using `CUDA-aware MPI`. All executor processes involved must share the same MPI world communicator. Consequently, with `TRTLLM_USE_MPI_KVCACHE=1`, TRT-LLM only supports launching multiple executors via `MPI`. Additionally, the `CommunicationMode` for the executors must be set to `kLEADER` or `kORCHESTRATOR` with `SpawnProcesses=false` for the `disaggregated-service`. These restrictions do not apply when `TRTLLM_USE_UCX_KVCACHE=1` is set. - *Q. Does TRT-LLM support using GPU direct RDMA for inter-node KV Cache transfer?* A. Yes, TRT-LLM supports using GPU direct RDMA for inter-node KV cache transfer. diff --git a/docs/source/advanced/expert-parallelism.md b/docs/source/advanced/expert-parallelism.md index 1d3d75540c..9541563be2 100644 --- a/docs/source/advanced/expert-parallelism.md +++ b/docs/source/advanced/expert-parallelism.md @@ -4,7 +4,7 @@ ## Mixture of Experts (MoE) -Mixture of Experts (MoE) architectures have been used widely recently, such as [Mistral Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json). Specifically, MOE’s structure supports multiple parallel Feedforward Neural Network (FFN) layers (called experts) to replace the single FFN layer in the dense model. When tokens arrive, the router layer selects the TopK experts for each token. The corresponding hidden state of the token is then dispatched to the selected TopK experts, respectively. As a result, there are multiple tokens’ hidden states that are dispatched to each expert. +Mixture of Experts (MoE) architectures have become widespread, with models such as [Mistral Mixtral 8×7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1). Specifically, MoE’s structure supports multiple parallel feed-forward neural-network (FFN) layers (called experts) in place of the single FFN layer in a dense model. When tokens arrive, the router layer selects the top-k experts for each token, and the corresponding hidden state of each token is dispatched to those experts. As a result, there are multiple tokens’ hidden states that are dispatched to each expert. <img src="https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/media/moe_structure.png?raw=true" alt="moe_structure" width="500" height="auto"> @@ -23,9 +23,8 @@ When both Tensor Parallel and Expert Parallel are enabled, each GPU handles a po ## How to Enable -The default parallel pattern is Tensor Parallel. You can enable Expert Parallel or hybrid parallel by setting `--moe_tp_size` and `--moe_ep_size` when calling `convert_coneckpoint.py`. If only `--moe_tp_size` is provided, TRT-LLM will use Tensor Parallel for the MoE model; if only `--moe_ep_size` is provided, TRT-LLM will use Expert Parallel; if both are provided, the hybrid parallel will be used. +The default parallel pattern is Tensor Parallel. You can enable Expert Parallel or hybrid parallel by setting `--moe_tp_size` and `--moe_ep_size` when calling `convert_checkpoint.py`. If only `--moe_tp_size` is provided, TRT-LLM will use Tensor Parallel for the MoE model; if only `--moe_ep_size` is provided, TRT-LLM will use Expert Parallel; if both are provided, the hybrid parallel will be used. Ensure the product of `moe_tp_size` and `moe_ep_size` is equal to `tp_size`, since the total number of MoE parallelism across all GPUs must match the total number of parallelism in other parts of the model. The other parameters related to the MoE structure, such as `num_experts_per_tok` (TopK in previous context) and `num_local_experts,` can be found in the model’s configuration file, such as the one for [Mixtral 8x7B model](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json). -) diff --git a/docs/source/advanced/gpt-attention.md b/docs/source/advanced/gpt-attention.md index 9fa1ae9b43..760637aed4 100644 --- a/docs/source/advanced/gpt-attention.md +++ b/docs/source/advanced/gpt-attention.md @@ -112,8 +112,6 @@ printed. #### XQA Optimization Another optimization for MQA/GQA in generation phase called XQA optimization. -It is still experimental feature and support limited configurations. LLAMA2 70B -is one model that it supports. Support matrix of the XQA optimization: - FP16 / BF16 compute data type. diff --git a/docs/source/advanced/speculative-decoding.md b/docs/source/advanced/speculative-decoding.md index 85a87ae062..c6975a423c 100644 --- a/docs/source/advanced/speculative-decoding.md +++ b/docs/source/advanced/speculative-decoding.md @@ -60,7 +60,8 @@ These tokens are then forwarded to the Target model for verification. Upon verification, the Target model may return up to `K+1` tokens. Subsequently, the prompt, now updated with the accepted tokens, is sent back to the Draft model to initiate the generation of new draft tokens. This iterative process continues until a predefined stop conditions are met. -An example of this orchestration process can be found in the [TensorRT-LLM Triton backend](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/inflight_batcher_llm/client/e2e_grpc_speculative_decoding_client.py). +An example orchestration script is available in the Triton backend repository’s +[draft-target-model client example](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/client/python/draft_target_model_client.py). We provide two styles of running Draft-Target-Model now: using TensorRT-LLM-BLS in Triton Inference Server, or using TensorRT-LLM directly. Detailed steps of running can be found in [examples/draft_target_model/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/draft_target_model/README.md) and the code can be found in [examples/ngram/run_dtm_ngram.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/ngram/run_dtm_ngram.py). @@ -168,11 +169,11 @@ TensorRT-LLM implements the ReDrafter model such that logits prediction, beam se The EAGLE approach enhances the single-model Medusa method by predicting and verifying tokens using the same model. Similarly to ReDrafter, it predicts draft tokens using a recurrent predictor where each draft token depends on the previous one. However, unlike ReDrafter, it uses a single-layer transformer model to predict draft tokens from previous hidden states and decoded tokens. In the EAGLE-1 decoding tree needs to be known during the decoding. In the EAGLE-2 this tree is asssembled during the execution by searching for the most probable hypothesis along the beam. -Similarly to ReDrafter, TensorRT-LLM implements the EAGLE model such that logits prediction, draft tokens acceptance and draft token generation are performed inside of the TensorRT engine. EAGLE-1 and EAGLE-2 are both supported, while EAGLE-2 is currently in the experimental stage. Please, visit the [EAGLE README](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/eagle/README.md) for information about building and running the model. +Similarly to ReDrafter, TensorRT-LLM implements the EAGLE model such that logits prediction, draft tokens acceptance and draft token generation are performed inside of the TensorRT engine(EAGLE-1 and EAGLE-2 are both supported). Please, visit the [EAGLE README](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/eagle/README.md) for information about building and running the model. ### Disaggregated Serving -[Disaggregated Serving](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/disaggregated-service.md) with EAGLE3 using the two model approach is supported in the Pytorch backend. Please refer to the following [Dynamo example](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/llama4_plus_eagle.md) on how to run EAGLE3 with Disaggregated Serving for Llama 4 Maverick. +[Disaggregated Serving](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/disaggregated-service.md) with EAGLE-3 using the two-model approach is supported in the PyTorch backend. ## Lookahead Decoding diff --git a/docs/source/architecture/model-weights-loader.md b/docs/source/architecture/model-weights-loader.md index eb393d4a7d..361c385349 100644 --- a/docs/source/architecture/model-weights-loader.md +++ b/docs/source/architecture/model-weights-loader.md @@ -249,7 +249,7 @@ for tllm_key, param in tqdm(trtllm_model.named_parameters()): In this mode, every precision require user's own support. ## Trouble shooting -The weights loader is an experimental feature for now, and is enabled for LLaMA family models and Qwen models by default. +The weights loader is enabled for LLaMA family models and Qwen models by default with TensorRT flow only. If users are encountered with failure caused by `ModelWeightsLoader`, a workaround is passing environmental variable `TRTLLM_DISABLE_UNIFIED_CONVERTER=1` to disable the model weights loader and fallback to the legacy path. diff --git a/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md b/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md index 05d18284a0..d3a115ef14 100644 --- a/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md +++ b/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md @@ -412,9 +412,10 @@ Generally, you should make sure that `max_batch_size` is not too low to bottlene For more details on `max_batch_size` and `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md). -### Not supported: MLA chunked context support on Hopper +### MLA chunked context + +MLA currently supports the chunked context feature on both Hopper and Blackwell GPUs. You can use `--enable_chunked_context` to enable it. This feature is primarily designed to reduce TPOT (Time Per Output Token). The default chunk size is set to `max_num_tokens`. If you want to achieve a lower TPOT, you can appropriately reduce the chunk size. However, please note that this will also decrease overall throughput. Therefore, a trade-off needs to be considered. -MLA chunked context support has been added on Blackwell GPUs, while it's not supported on Hopper yet. On Hopper, note that `max_num_tokens` has to be at least larger than the max input sequence length of the samples in dataset. For more details on `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md). ### Out of memory issues diff --git a/docs/source/blogs/Falcon180B-H200.md b/docs/source/blogs/Falcon180B-H200.md index f2c2fe7592..f9c7f760f1 100644 --- a/docs/source/blogs/Falcon180B-H200.md +++ b/docs/source/blogs/Falcon180B-H200.md @@ -33,7 +33,7 @@ Often quantization can have adverse impacts on the accuracy of the model, however, TensorRT-LLM's AWQ decreases memory footprint of the model by **4x** while maintaining high accuracy. -<img src="https://github.com/NVIDIA/TensorRT-LLM/blob/rel/docs/source/blogs/media/Falcon180B-H200_acc.png?raw=true" alt="Falcon-180B accuracy comparison" width="600" height="auto"> +<img src="https://github.com/NVIDIA/TensorRT-LLM/blob/5aec7af45fc0abd876fa68a9ae8c8cae084f3af3/docs/source/blogs/media/Falcon180B-H200_acc.png?raw=true" alt="Falcon-180B accuracy comparison" width="600" height="auto"> <sup>Preliminary measured accuracy, subject to change. </sup> diff --git a/docs/source/blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.md b/docs/source/blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.md index 201c3781a8..48f6728eab 100644 --- a/docs/source/blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.md +++ b/docs/source/blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.md @@ -125,7 +125,7 @@ The modules in the diagram are: | Baseline: CUDA Graph + EP8TP8 | 67 | [modeling_deepseekv3.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/models/modeling_deepseekv3.py) | | Multi Stream to overlap shared expert with sparse experts | 73 | [modeling_deepseekv3.py#L506](https://github.com/NVIDIA/TensorRT-LLM/blob/14bfb5e0d6e81aec3306a1324cf074566646f886/tensorrt_llm/_torch/models/modeling_deepseekv3.py#L506) | | Optimize MLA Kernel | 80 | [PR #3763](https://github.com/NVIDIA/TensorRT-LLM/pull/3763) | -| Optimize TopK Kernels | 84 | • [RoutingKernel.cu](https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/trtllmGenSrc/RoutingKernel.cu)<br/>• [noAuxTcKernels.cu](https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu) | +| Optimize TopK Kernels | 84 | • [RoutingKernelTopK.cuh](https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh)<br/>• [noAuxTcKernels.cu](https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu) | | Optimize Fuse_A_GEMM | 89 | [attention.py#L345](https://github.com/NVIDIA/TensorRT-LLM/blob/d6b741ddfe7f8a80718c10d49773c42abc0a254f/tensorrt_llm/_torch/modules/attention.py#L345) | | MTP3_Vanilla | 154 | evolve to MTP3_Autoregressive | | Evolve to MTP3_Autoregressive + Optimize Router GEMM | 164 | [modeling_deepseekv3.py#L304](https://github.com/NVIDIA/TensorRT-LLM/blob/d6b741ddfe7f8a80718c10d49773c42abc0a254f/tensorrt_llm/_torch/models/modeling_deepseekv3.py#L304) | diff --git a/docs/source/blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.md b/docs/source/blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.md index 8c9e935585..8d7682c482 100644 --- a/docs/source/blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.md +++ b/docs/source/blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.md @@ -277,7 +277,7 @@ We also conducted performance evaluations of Qwen 3 on GB200 GPUs. The data indi ### Reproducing Steps -We provide a set of scripts to reproduce the performance data presented in this paper. Please refer to the usage instructions described in [this document](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/disaggregated/slurm). +We provide a set of scripts to reproduce the performance data presented in this paper. Please refer to the usage instructions described in [this document](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/disaggregated/slurm/benchmark). ## Future Work diff --git a/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md b/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md index b50171ddf7..8f5c1dfec0 100644 --- a/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md +++ b/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md @@ -1,30 +1,29 @@ # Running a High Performance GPT-OSS-120B Inference Server with TensorRT-LLM -In the guide below, we will walk you through how to launch your own +NVIDIA has [announced](https://developer.nvidia.com/blog/delivering-1-5-m-tps-inference-on-nvidia-gb200-nvl72-nvidia-accelerates-openai-gpt-oss-models-from-cloud-to-edge/) day-0 support for OpenAI's new open-source model series, [gpt-oss](https://openai.com/index/introducing-gpt-oss/). In the guide below, we will walk you through how to launch your own high-performance TensorRT-LLM server for **gpt-oss-120b** for inference. -This guide covers both low-latency and max-throughput cases. -The typical use case for **low-latency**, is when we try to maximize the number of tokens per second per user with a limited concurrency (4, 8 or 16 users). - -For **maximum throughput**, the goal is to maximize the amount of tokens produced per GPU per second. The former is an indication of how fast a system can produce tokens, the latter measures how many tokens a "chip" can generate per unit of time. +**Low-latency** use cases aim to maximize the number of tokens per second per user (tps/user) with limited concurrency. +For **max-throughput**, the goal is to maximize the tokens produced per GPU per second (tps/gpu). While tps/user indicates user experience quality, tps/gpu measures the economic efficiency of the system. ## Prerequisites -- 1x NVIDIA B200/GB200/H200 GPU (8x NVIDIA B200/H200 GPUs or 4x GB200 GPUs in a single node recommended for higher performance) -- CUDA Toolkit 12.8 or later -- Docker with NVIDIA Container Toolkit installed +- 1x NVIDIA B200/GB200/H200 GPU (more GPUs could be used for lower latency and higher throughput) - Fast SSD storage for model weights - Access to the gpt-oss-120b model checkpoint -We have a forthcoming guide for getting great performance on H100, however this guide focuses on the above GPUs. +We have a forthcoming guide for achieving great performance on H100; however, this guide focuses on the GPUs listed above. +## Install TensorRT-LLM -## Launching the TensorRT-LLM docker container +In this section, we introduce several ways to install TensorRT-LLM. -The container image that you will use will be pulled from NVIDIA's NGC. This container is multi-platform and will run on both x64 and arm64 architectures: `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev` +### NGC Docker Image of dev branch -Run the follow docker command to start the TensorRT-LLM container in interactive mode: +Day-0 support for gpt-oss is provided via the NGC container image `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev`. This image was built on top of the pre-day-0 **dev branch**. This container is multi-platform and will run on both x64 and arm64 architectures. + +Run the following docker command to start the TensorRT-LLM container in interactive mode: ```bash docker run --rm --ipc=host -it \ @@ -33,117 +32,134 @@ docker run --rm --ipc=host -it \ --gpus all \ -p 8000:8000 \ -e TRTLLM_ENABLE_PDL=1 \ - -e TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL=True \ -v ~/.cache:/root/.cache:rw \ nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev \ /bin/bash ``` -This command: +Explanation of the command: - Automatically removes the container when stopped (`--rm`) - Allows container to interact with the host's IPC resources and shared memory for optimal performance (`--ipc=host`) - Runs the container in interactive mode (`-it`) - Sets up shared memory and stack limits for optimal performance -- Maps port 8000 from the container to your host -- enables PDL for low-latency perf optimization -- disables parallel weight loading +- Maps port 8000 from the container to the host +- Enables PDL for performance optimization -Lastly the container mounts your user `.cache` directory to save the downloaded model checkpoints which are saved to `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. +Additionally, the container mounts your user `.cache` directory to save the downloaded model checkpoints, which are stored in `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. You can also download the weights to a custom location (we assume `${local_model_path}` is the path to the local model weights). + +### Build from source + +Support for gpt-oss has been [merged](https://github.com/NVIDIA/TensorRT-LLM/pull/6645) into the **main branch** of TensorRT-LLM. As we continue to optimize gpt-oss performance, you can build TensorRT-LLM from source to get the latest features and support. Please refer to the [doc](https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html) if you want to build from source yourself. -## Running the TensorRT-LLM Server +### Regular Release of TensorRT-LLM -As pointed out in the introduction, this guide covers low-latency and max-throughput cases. Each requires a different configurations and commands to run. We will first cover the Low-Latency use-case, followed by the max throughput use-case. +Since gpt-oss has been supported on the main branch, you can get TensorRT-LLM out of the box through its regular release in the future. Please check the latest [release notes](https://github.com/NVIDIA/TensorRT-LLM/releases) to keep track of the support status. The release is provided as [NGC Container Image](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags) or [pip Python wheel](https://pypi.org/project/tensorrt-llm/#history). You can find instructions on pip install [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html). + + +## Performance Benchmarking and Model Serving + +This guide covers how to configure for both low-latency and max-throughput cases, as well as how to benchmark end-to-end performance. + +### Prepare the dataset +Before getting started, we need to prepare a dataset of randomized tokens for benchmarking: + +```bash +python benchmarks/cpp/prepare_dataset.py \ + --stdout \ + --tokenizer openai/gpt-oss-120b \ + token-norm-dist \ + --input-mean 1024 \ + --output-mean 2048 \ + --input-stdev 0 \ + --output-stdev 0 \ + --num-requests 20000 > gpt-oss-120b-1k2k.txt +``` + +### Low-latency Use Case + +The low-latency configuration maximizes tps/user under limited concurrency (e.g., 1, 4, 8, or 16 users). Please set the number of GPUs and concurrency according to your specific situation and workload. + +```bash +num_gpus=8 +max_batch_size=1 +``` -### Low-latency Use-Case #### Creating the Extra Options Configuration -To run a server for low-latency workloads, create a YAML configuration file, `low_latency.yaml`, as follows: +Create a YAML configuration file, `low_latency.yaml`, as follows: -```yaml +```bash cat <<EOF > low_latency.yaml enable_attention_dp: false -enable_mixed_sampler: true cuda_graph_config: - max_batch_size: 8 + max_batch_size: ${max_batch_size} enable_padding: true moe_config: backend: TRTLLM EOF ``` -> Note: If you are using NVIDIA H200 GPUs it is highly recommended to set the `moe_config.backend` to TRITON to use the OpenAI Triton MoE kernel. See the section [(H200 Only) Using OpenAI Triton Kernels for MoE](#h200-only-using-openai-triton-kernels-for-moe) for more details. +Key takeaways: +- `enable_attention_dp` is set to `false` to use TP instead of DP for attention. +s- `cuda_graph_config.max_batch_size` is the maximum batch size for CUDA graph. +- `cuda_graph_config.enable_padding` is set to `true` to enable CUDA graph padding. +- `moe_config.backend` is set to `TRTLLM` to use the `trtllm-gen` MoE kernels which are optimized for low concurrency. -#### Launching TensorRT-LLM Serve +> Note: If you are using NVIDIA H200 GPUs please set the `moe_config.backend` to `TRITON` to use the OpenAI Triton MoE kernel regardless of use case. See the section [(H200/H100 Only) Using OpenAI Triton Kernels for MoE](#h200h100-only-using-openai-triton-kernels-for-moe) for more details. -To launch the TensorRT-LLM Server to serve the model with the **low latency** config, run the following command. Commands for different GPU configurations are provided (1xGPU, 8xGPU, 4xGPU): -<details open> <summary>1x B200/GB200/H200</summary> +#### Run the benchmark +Use `trtllm-bench` to benchmark the performance of your system: ```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 1 \ - --ep_size 1 \ - --trust_remote_code \ - --extra_llm_api_options low_latency.yaml \ - --kv_cache_free_gpu_memory_fraction 0.75 +trtllm-bench \ + --model openai/gpt-oss-120b \ + --model_path ${local_model_path} \ + throughput \ + --backend pytorch \ + --tp ${num_gpus} \ + --ep 1 \ + --extra_llm_api_options low_latency.yaml \ + --dataset gpt-oss-120b-1k2k.txt \ + --max_batch_size ${max_batch_size} \ + --concurrency ${max_batch_size} \ + --num_requests $((max_batch_size * 10)) \ + --kv_cache_free_gpu_mem_fraction 0.9 \ + --streaming \ + --warmup 0 \ + --report_json low_latency_benchmark.json ``` -</details> -<details> <summary>8x B200/H200</summary> +`--max_batch_size` controls the maximum batch size that the inference engine could serve, while `--concurrency` is the number of concurrent requests that the benchmarking client is sending. `--num_requests` is set to 10 times of `--concurrency` to run enough number of requests. + +Note that you can set `--ep` to a value larger than 1, which will enable mixed TP/EP for MoE. In minimum-latency scenarios, we recommend a small EP size to avoid load imbalance in MoE. + +For reference, we achieve **420 tps/user** with 8x B200 GPUs and batch size 1. + + +### Max-Throughput Use Case + +The max-throughput configuration maximizes tps/gpu at high concurrency levels. With increasing concurrency, we trade per-user latency for higher throughput that saturates the system's GPUs. Using input sequence length (isl) of 1k and output sequence length (osl) of 2k, we can currently achieve a batch size of 640 with 8x B200 GPUs. ```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 8 \ - --ep_size 8 \ - --trust_remote_code \ - --extra_llm_api_options low_latency.yaml \ - --kv_cache_free_gpu_memory_fraction 0.75 +num_gpus=8 +max_batch_size=640 ``` -</details> -<details> <summary>4x GB200/B200/H200</summary> - -```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 4 \ - --ep_size 4 \ - --trust_remote_code \ - --extra_llm_api_options low_latency.yaml \ - --kv_cache_free_gpu_memory_fraction 0.75 -``` -</details> - - - - -### Max-Throughput Use-Case #### Creating the Extra Options Configuration -To run a server for max-throughput workloads, create a YAML configuration file, -`max_throughput.yaml`, as follows: +Like before, create a YAML configuration file, `max_throughput.yaml`, as follows: -```yaml +```bash cat <<EOF > max_throughput.yaml enable_attention_dp: true cuda_graph_config: - max_batch_size: 640 + max_batch_size: ${max_batch_size} enable_padding: true stream_interval: 10 moe_config: @@ -151,97 +167,85 @@ moe_config: EOF ``` -> Note: If you are using NVIDIA H200 GPUs it is highly recommended to set the `moe_config.backend` to TRITON to use the OpenAI Triton MoE kernel. See the section [(H200 Only) Using OpenAI Triton Kernels for MoE](#h200-only-using-openai-triton-kernels-for-moe) for more details. +Compared to the low-latency configuration, we: +- set `enable_attention_dp` to `true` to use attention DP which is better for high throughput. +- set `stream_interval` to 10 to stream results to the client every 10 tokens. At high concurrency, the detokenization overhead of streaming mode cannot be hidden under GPU execution time, so `stream_interval` serves as a workaround to reduce this overhead. +- set `moe_config.backend` to `CUTLASS` to use the `CUTLASS` MoE kernels which are optimized for high throughput. -#### Launching TensorRT-LLM Serve +#### Run the benchmark -To launch the TensorRT-LLM Server to serve the model with the **max throughput** config, run the following command. Commands for different GPU configurations are provided (1xGPU, 8xGPU, 4xGPU): - -<details open> <summary>1x B200/GB200/H200</summary> +Run the following command to benchmark the throughput of your system: ```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ +trtllm-bench \ + --model openai/gpt-oss-120b \ + --model_path ${local_model_path} \ + throughput \ + --backend pytorch \ + --tp ${num_gpus} \ + --ep ${num_gpus} \ + --extra_llm_api_options max_throughput.yaml \ + --dataset gpt-oss-120b-1k2k.txt \ + --max_batch_size ${max_batch_size} \ + --concurrency $((max_batch_size * num_gpus)) \ + --num_requests $((max_batch_size * num_gpus * 3)) \ + --kv_cache_free_gpu_mem_fraction 0.9 \ + --streaming \ + --warmup 0 \ + --report_json max_throughput_benchmark.json +``` + +Note: +- `CUTLASS` MoE backend only supports pure EP for MoE, so we set `--ep` to `num_gpus`. +- When using `enable_attention_dp`, `max_batch_size` describes the maximum batch size for each local rank, so to saturate the system, we need to multiply `max_batch_size` by `num_gpus` for `--concurrency`. +- `--num_requests` is set to 3 times `--concurrency` to run enough number of requests. + +Currently, the best throughput **19.5k tps/gpu** is achieved with DP4EP4 using 4x B200 GPUs and over **20k tps/gpu** on GB200 GPUs due to slightly better performance of GB200, which translates to over **1.5M tps** on a GB200 NVL72 system. In theory, even better tps/gpu could be achieved with larger world size due to larger allowable batch size and smaller MoE weights per-GPU, but the communication implementation for >4GPUs is suboptimal and we are actively working on improving it. + + + +## Launch the TensorRT-LLM Server + +We can use `trtllm-serve` to serve the model by translating the benchmark commands above. For low-latency configuration, run: + +```bash +trtllm-serve \ + gpt-oss-120b \ # Or ${local_model_path} --host 0.0.0.0 \ --port 8000 \ --backend pytorch \ - --tp_size 1 \ - --ep_size 1 \ - --max_batch_size 640 \ - --trust_remote_code \ - --extra_llm_api_options max_throughput.yaml \ - --kv_cache_free_gpu_memory_fraction 0.9 + --tp_size ${num_gpus} \ + --ep_size 1 \ + --extra_llm_api_options low_latency.yaml \ + --kv_cache_free_gpu_memory_fraction 0.9 \ + --max_batch_size ${max_batch_size} \ # E.g., 1 + --trust_remote_code ``` -</details> - -<details> <summary>8x B200/H200</summary> - -```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 8 \ - --ep_size 8 \ - --max_batch_size 640 \ - --trust_remote_code \ - --extra_llm_api_options max_throughput.yaml \ - --kv_cache_free_gpu_memory_fraction 0.9 -``` -</details> - -<details> <summary>4x GB200/B200/H200</summary> - -```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 4 \ - --ep_size 4 \ - --max_batch_size 640 \ - --trust_remote_code \ - --extra_llm_api_options max_throughput.yaml \ - --kv_cache_free_gpu_memory_fraction 0.9 -``` -</details> - - -This command: -- Maps port 8000 from the container to your host -- Uses the PyTorch backend and specifies the tensor and expert parallel sizes -- References the low latency or max throughput configuration file for extra options -- Configures memory settings for optimal performance -- Enables all GPUs with attention data parallelism for the max throughput scenario The initialization may take several minutes as it loads and optimizes the models. - -## (H200 Only) Using OpenAI Triton Kernels for MoE - -OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels for Hopper based GPUs like NVIDIA's H200 for best performance. The NGC TensorRT-LLM container image mentioned above already includes the required kernels so you do not need to build or install them. It is highly recommended to enable them with the steps below: - -### Selecting Triton as the MoE backend - -To use the Triton MoE backend with **trtllm-serve** (or other similar commands) add this snippet to the YAML file passed via `--extra_llm_api_options`: - -```yaml -moe_config: - backend: TRITON -``` - -Alternatively the TRITON backend can be enabled by passing the CLI flag to the trtllm-server command at runtime: +For max-throughput configuration, run: ```bash ---moe_backend TRITON +trtllm-serve \ + gpt-oss-120b \ # Or ${local_model_path} + --host 0.0.0.0 \ + --port 8000 \ + --backend pytorch \ + --tp_size ${num_gpus} \ + --ep_size ${num_gpus} \ + --extra_llm_api_options max_throughput.yaml \ + --kv_cache_free_gpu_memory_fraction 0.9 \ + --max_batch_size ${max_batch_size} \ # E.g., 640 + --trust_remote_code ``` -## Test the Server with a Sample Request -You can query the health/readiness of the server using +### Test the Server with a Sample Request + + +To check the server's health and readiness: ```bash curl -s -o /dev/null -w "Status: %{http_code}\n" "http://localhost:8000/health" @@ -252,14 +256,13 @@ very first query may take longer due to initialization and compilation. Once the server is running, you can test it with a simple curl request: - ```bash curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "openai/gpt-oss-120b", "messages": [ { "role": "user", - "content": "What is NVIDIAs advantage for inference?" + "content": "What is NVIDIA's advantage for inference?" } ], "max_tokens": 1024, @@ -343,20 +346,29 @@ requests. You can adjust parameters like `max_tokens`, `temperature`, and others according to your needs. +## (H200/H100 Only) Using OpenAI Triton Kernels for MoE + +OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels for Hopper-based GPUs like NVIDIA's H200 for optimal performance. `TRTLLM` MoE backend is not supported on Hopper, and `CUTLASS` backend support is still ongoing. Please enable `TRITON` backend with the steps below if you are running on Hopper GPUs. + +### Installing OpenAI Triton + +The `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev` has prepared Triton already (`echo $TRITON_ROOT` could reveal the path). In other situations, you will need to build and install a specific version of Triton. Please follow the instructions in this [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/gpt_oss#using-openai-triton-kernels-for-moe). + + +### Selecting Triton as the MoE backend + +To use the Triton MoE backend with **trtllm-serve** (or other commands), add this snippet to the YAML file passed via `--extra_llm_api_options`: + +```yaml +moe_config: + backend: TRITON +``` + + ## Troubleshooting Tips -- If you encounter CUDA out-of-memory errors, try reducing `max_batch_size`, `max_seq_len`, or `--kv_cache_free_gpu_memory_fraction` -- Ensure your model checkpoints are compatible with the expected format -- For performance issues, check GPU utilization with `nvidia-smi` while the server is running +- If you encounter CUDA out-of-memory errors, try reducing `--max_batch_size`, `--max_num_tokens`, or `--kv_cache_free_gpu_memory_fraction`. See the [doc](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md) for the explanation of these parameters. +- Add `print_iter_log: true` to extra LLM API options YAML file to inspect the per-iteration log. +- Check GPU utilization with `nvidia-smi` while the server is running to inspect GPU status and memory usage. - If the container fails to start, verify that the NVIDIA Container Toolkit is properly installed - For connection issues, make sure port 8000 is not being used by another application - - -## Performance Tuning - -The configuration provided is optimized for 8xB200 GPUs, but you can adjust -several parameters for your specific workload: - -- `max_batch_size`: Controls how many requests can be batched together -- `max_draft_len`: The number of tokens Eagle can speculate ahead -- `kv_cache_free_gpu_memory_fraction`: Controls memory allocation for the KV cache diff --git a/docs/source/commands/trtllm-serve/run-benchmark-with-trtllm-serve.md b/docs/source/commands/trtllm-serve/run-benchmark-with-trtllm-serve.md index 161535e96e..eee15e2b68 100644 --- a/docs/source/commands/trtllm-serve/run-benchmark-with-trtllm-serve.md +++ b/docs/source/commands/trtllm-serve/run-benchmark-with-trtllm-serve.md @@ -3,11 +3,12 @@ TensorRT-LLM provides the OpenAI-compatiable API via `trtllm-serve` command. A complete reference for the API is available in the [OpenAI API Reference](https://platform.openai.com/docs/api-reference). -This step-by-step tutorial covers the following topics for running online serving benchmarking with Llama 3.1 70B: +This step-by-step tutorial covers the following topics for running online serving benchmarking with Llama 3.1 70B and Qwen2.5-VL-7B for multimodal models: * Methodology Introduction * Launch the OpenAI-Compatibale Server with NGC container * Run the performance benchmark * Using `extra_llm_api_options` + * Multimodal Serving and Benchmarking ## Methodology Introduction @@ -220,3 +221,78 @@ The following is a list of common performance switches.  **Default**: TRTLLM See the [TorchLlmArgs class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options which can be used in the extra\_llm\_api\_options`.` + +## Multimodal Serving and Benchmarking + +TensorRT-LLM supports multimodal models for both serving and benchmarking. This section covers how to set up multimodal serving and run benchmarks for multimodal models. + +### Setting up Multimodal Serving + +Here's an example of setting up multimodal serving with Qwen2.5-VL: + +```bash +#!/bin/bash +model_path=/path/to/qwen2.5vl-7B_model + +trtllm-serve ${model_path} \ + --max_batch_size 64 \ + --max_num_tokens 8192 \ + --max_seq_len 4096 \ + --kv_cache_free_gpu_memory_fraction 0.9 \ + --tp_size 1 \ + --ep_size 1 \ + --trust_remote_code +``` + +### Multimodal Benchmarking + +For multimodal serving benchmarks, you can use the `benchmark_serving.py` script with multimodal datasets: + +```bash +python -m tensorrt_llm.serve.scripts.benchmark_serving \ + --model ${model_path} \ + --backend openai-chat \ + --dataset-name "random_image" \ + --random-input-len 128 \ + --random-output-len 128 \ + --random-image-width 512 \ + --random-image-height 512 \ + --random-num-images 1 \ + --num-prompts 100 \ + --max-concurrency 8 \ + --ignore-eos +``` + +Below is some example TensorRT-LLM serving benchmark output. Your actual results may vary. +``` +============ Serving Benchmark Result ============ +Successful requests: 1 +Benchmark duration (s): 0.83 +Total input tokens: 128 +Total generated tokens: 128 +Request throughput (req/s): 1.20 +Output token throughput (tok/s): 153.92 +Total Token throughput (tok/s): 307.85 +User throughput (tok/s): 154.15 +Mean Request AR: 0.9845 +Median Request AR: 0.9845 +---------------Time to First Token---------------- +Mean TTFT (ms): 84.03 +Median TTFT (ms): 84.03 +P99 TTFT (ms): 84.03 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 5.88 +Median TPOT (ms): 5.88 +P99 TPOT (ms): 5.88 +---------------Inter-token Latency---------------- +Mean ITL (ms): 5.83 +Median ITL (ms): 5.88 +P99 ITL (ms): 6.14 +================================================== +``` + +**Notes for Multimodal Benchmarking:** +- Set `--backend` as `openai-chat` since multimodal models are only supported on the chat API and require a chat template +- Control the number of images per request with `--random-num-images` +- Use `--random-image-width` and `--random-image-height` to specify image dimensions or `--random-image-size` for squared image dimensions. +- The `random_image` dataset generates synthetic images for benchmarking diff --git a/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md b/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md new file mode 100644 index 0000000000..dd9899e2a3 --- /dev/null +++ b/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md @@ -0,0 +1,397 @@ +# Quick Start Recipe for DeepSeek R1 on TensorRT-LLM - Blackwell & Hopper Hardware + +## Introduction + +This deployment guide provides step-by-step instructions for running the DeepSeek R1 model using TensorRT-LLM with FP8 and NVFP4 quantization, optimized for NVIDIA GPUs. It covers the complete setup required; from accessing model weights and preparing the software environment to configuring TensorRT-LLM parameters, launching the server, and validating inference output. + +The guide is intended for developers and practitioners seeking high-throughput or low-latency inference using NVIDIA’s accelerated stack—starting with the PyTorch container from NGC, then installing TensorRT-LLM for model serving, FlashInfer for optimized CUDA kernels, and ModelOpt to enable FP8 and NVFP4 quantized execution. + +## Prerequisites + +* GPU: NVIDIA Blackwell or Hopper Architecture +* OS: Linux +* Drivers: CUDA Driver 575 or Later +* Docker with NVIDIA Container Toolkit installed +* Python3 and python3-pip (Optional, for accuracy evaluation only) + +## Models + +* FP8 model: [DeepSeek-R1-0528](https://huggingface.co/deepseek-ai/DeepSeek-R1-0528) +* NVFP4 model: [DeepSeek-R1-0528-FP4](https://huggingface.co/nvidia/DeepSeek-R1-0528-FP4) + + +## MoE Backend Support Matrix + +There are multiple MOE backends inside TRT-LLM, not all of them supporting every precision on every GPUs. Here are the support matrix of the MOE backends. + +| device | Checkpoint | Supported moe_backend | +|----------|----------|----------| +| H100/H200 | FP8 | CUTLASS | +| B200/GB200 EP<=8 | NVFP4 | CUTLASS, TRTLLM | +| B200/GB200 EP<=8 | FP8 | DEEPGEMM | +| GB200 NVL72 EP>8 | NVFP4 | WIDEEP | +| GB200 NVL72 EP>8 | FP8 | N/A (WIP) | + +The default moe backend is `CUTLASS`, so for the combination which is not supported by `CUTLASS`, one must set the `moe_config.backend` explicitly to run the model. + +## Deployment Steps + +### Run Docker Container + +Run the docker container using the TensorRT-LLM NVIDIA NGC image. + +```shell +docker run --rm -it \ +--ipc=host \ +--gpus all \ +-p 8000:8000 \ +-v ~/.cache:/root/.cache:rw \ +--name tensorrt_llm \ +nvcr.io/nvidia/tensorrt-llm/release:1.0.0rc6 \ +/bin/bash +``` + +Note: + +* The command mounts your user `.cache` directory to save the downloaded model checkpoints which are saved to `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. If the `~/.cache` directory doesn’t exist please create it using `$ mkdir ~/.cache`. +* You can mount additional directories and paths using the `-v <host_path>:<container_path>` flag if needed, such as mounting the downloaded weight paths. +* The command also maps port `8000` from the container to your host so you can access the LLM API endpoint from your host +* See the <https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags> for all the available containers. The containers published in the main branch weekly have `rcN` suffix, while the monthly release with QA tests has no `rcN` suffix. Use the `rc` release to get the latest model and feature support. + +If you want to use latest main branch, you can choose to build from source to install TensorRT-LLM, the steps refer to <https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html>. + +### Creating the TRT-LLM Server config + +We create a YAML configuration file `/tmp/config.yml` for the TensorRT-LLM Server and populate it with the following recommended performance settings. + +```shell +EXTRA_LLM_API_FILE=/tmp/config.yml + +cat << EOF > ${EXTRA_LLM_API_FILE} +enable_attention_dp: true +cuda_graph_config: + enable_padding: true + max_batch_size: 128 +kv_cache_config: + dtype: fp8 +stream_interval: 10 +speculative_config: + decoding_type: MTP + num_nextn_predict_layers: 1 +EOF +``` + +For FP8 model, we need extra `moe_config`: + +```shell +EXTRA_LLM_API_FILE=/tmp/config.yml + +cat << EOF > ${EXTRA_LLM_API_FILE} +enable_attention_dp: true +cuda_graph_config: + enable_padding: true + max_batch_size: 128 +kv_cache_config: + dtype: fp8 +stream_interval: 10 +speculative_config: + decoding_type: MTP + num_nextn_predict_layers: 1 +moe_config: + backend: DEEPGEMM + max_num_tokens: 3200 +EOF +``` + +### Launch the TRT-LLM Server + +Below is an example command to launch the TRT-LLM server with the DeepSeek-R1 model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section. + +```shell +trtllm-serve deepseek-ai/DeepSeek-R1-0528 \ + --host 0.0.0.0 \ + --port 8000 \ + --backend pytorch \ + --max_batch_size 1024 \ + --max_num_tokens 3200 \ + --max_seq_len 2048 \ + --kv_cache_free_gpu_memory_fraction 0.8 \ + --tp_size 8 \ + --ep_size 8 \ + --trust_remote_code \ + --extra_llm_api_options ${EXTRA_LLM_API_FILE} +``` + +After the server is set up, the client can now send prompt requests to the server and receive results. + +### Configs and Parameters + +These options are used directly on the command line when you start the `trtllm-serve` process. + +#### `--tp_size` + +* **Description:** Sets the **tensor-parallel size**. This should typically match the number of GPUs you intend to use for a single model instance. + +#### `--ep_size` + +* **Description:** Sets the **expert-parallel size** for Mixture-of-Experts (MoE) models. Like `tp_size`, this should generally match the number of GPUs you're using. This setting has no effect on non-MoE models. + +#### `--kv_cache_free_gpu_memory_fraction` + +* **Description:** A value between `0.0` and `1.0` that specifies the fraction of free GPU memory to reserve for the KV cache after the model is loaded. Since memory usage can fluctuate, this buffer helps prevent out-of-memory (OOM) errors. +* **Recommendation:** If you experience OOM errors, try reducing this value to `0.7` or lower. + +#### `--backend pytorch` + +* **Description:** Tells TensorRT-LLM to use the **pytorch** backend. + +#### `--max_batch_size` + +* **Description:** The maximum number of user requests that can be grouped into a single batch for processing. + +#### `--max_num_tokens` + +* **Description:** The maximum total number of tokens (across all requests) allowed inside a single scheduled batch. + +#### `--max_seq_len` + +* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens. + +#### `--trust_remote_code` + +* **Description:** Allows TensorRT-LLM to download models and tokenizers from Hugging Face. This flag is passed directly to the Hugging Face API. + + +#### Extra LLM API Options (YAML Configuration) + +These options provide finer control over performance and are set within a YAML file passed to the `trtllm-serve` command via the `--extra_llm_api_options` argument. + +#### `kv_cache_config` + +* **Description**: A section for configuring the Key-Value (KV) cache. + +* **Options**: + + * `dtype`: Sets the data type for the KV cache. + **Default**: `"auto"` (uses the data type specified in the model checkpoint). + +#### `cuda_graph_config` + +* **Description**: A section for configuring CUDA graphs to optimize performance. + +* **Options**: + + * `enable_padding`: If `"true"`, input batches are padded to the nearest `cuda_graph_batch_size`. This can significantly improve performance. + + **Default**: `false` + + * `max_batch_size`: Sets the maximum batch size for which a CUDA graph will be created. + + **Default**: `0` + + **Recommendation**: Set this to the same value as the `--max_batch_size` command-line option. + + * `batch_sizes`: A specific list of batch sizes to create CUDA graphs for. + + **Default**: `None` + +#### `moe_config` + +* **Description**: Configuration for Mixture-of-Experts (MoE) models. + +* **Options**: + + * `backend`: The backend to use for MoE operations. + **Default**: `CUTLASS` + +#### `attention_backend` + +* **Description**: The backend to use for attention calculations. + +* **Default**: `TRTLLM` + +See the [`TorchLlmArgs` class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options which can be used in the `extra_llm_api_options`. + +## Testing API Endpoint + +### Basic Test + +Start a new terminal on the host to test the TensorRT-LLM server you just launched. + +You can query the health/readiness of the server using: + +```shell +curl -s -o /dev/null -w "Status: %{http_code}\n" "http://localhost:8000/health" +``` + +When the `Status: 200` code is returned, the server is ready for queries. Note that the very first query may take longer due to initialization and compilation. + +After the TRT-LLM server is set up and shows Application startup complete, you can send requests to the server. + +```shell +curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ + "model": "deepseek-ai/DeepSeek-R1-0528", + "prompt": "Where is New York?", + "max_tokens": 16, + "temperature": 0 +}' +``` + +Here is an example response, showing that the TRT-LLM server returns “New York is a state located in the northeastern United States. It is bordered by”, completing the input sequence. + +```json +{"id":"cmpl-e728f08114c042309efeae4df86a50ca","object":"text_completion","created":1754294810,"model":"deepseek-ai/DeepSeek-R1-0528","choices":[{"index":0,"text":" / by Megan Stine ; illustrated by John Hinderliter.\n\nBook | Gross","token_ids":null,"logprobs":null,"context_logits":null,"finish_reason":"length","stop_reason":null,"disaggregated_params":null}],"usage":{"prompt_tokens":6,"total_tokens":22,"completion_tokens":16},"prompt_token_ids":null} +``` + +### Troubleshooting Tips + +* If you encounter CUDA out-of-memory errors, try reducing `max_batch_size` or `max_seq_len`. + * For running input/output sequence lengths of 8K/1K on H200, there is a known CUDA Out-Of-Memory issue caused by the PyTorch CUDA Caching Allocator fragmenting memory. As a workaround, you can set the environment variable `PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:8192`. For more details, please refer to the [PyTorch documentation on optimizing memory usage](https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf). +* Ensure your model checkpoints are compatible with the expected format. +* For performance issues, check GPU utilization with nvidia-smi while the server is running. +* If the container fails to start, verify that the NVIDIA Container Toolkit is properly installed. +* For connection issues, make sure the server port (`8000` in this guide) is not being used by another application. + +### Running Evaluations to Verify Accuracy (Optional) + +We use the `lm-eval` tool to test the model’s accuracy. For more information see [https://github.com/EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). + +To run the evaluation harness exec into the running TensorRT-LLM container and install with this command: + +```shell +docker exec -it tensorrt_llm /bin/bash + +pip install lm_eval +``` + +FP8 command for GSM8K: + +* Note: The tokenizer will add BOS (beginning of sentence token) before input prompt by default which leads to accuracy regression on GSM8K task for DeepSeek R1 model. So, set `add_special_tokens=False` to avoid it. + +```shell +MODEL_PATH=deepseek-ai/DeepSeek-R1-0528 + +lm_eval --model local-completions --tasks gsm8k --batch_size 256 --gen_kwargs temperature=0.0,add_special_tokens=False --num_fewshot 5 --model_args model=${MODEL_PATH},base_url=http://localhost:8000/v1/completions,num_concurrent=32,max_retries=20,tokenized_requests=False --log_samples --output_path trtllm.fp8.gsm8k +``` + +Sample result in Blackwell: + +``` +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9538|± |0.0058| +| | |strict-match | 5|exact_match|↑ |0.9500|± |0.0060| +``` + +FP4 command for GSM8K: + +* Note: The tokenizer will add BOS before input prompt by default, which leads to accuracy regression on GSM8K task for DeepSeek R1 model. So set `add_special_tokens=False` to avoid it. + +```shell +MODEL_PATH=nvidia/DeepSeek-R1-0528-FP4 + +lm_eval --model local-completions --tasks gsm8k --batch_size 256 --gen_kwargs temperature=0.0,add_special_tokens=False --num_fewshot 5 --model_args model=${MODEL_PATH},base_url=http://localhost:8000/v1/completions,num_concurrent=32,max_retries=20,tokenized_requests=False --log_samples --output_path trtllm.fp4.gsm8k +``` + +Sample result in Blackwell: + +``` +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9462|± |0.0062| +| | |strict-match | 5|exact_match|↑ |0.9447|± |0.0063| +``` + +## Benchmarking Performance + +To benchmark the performance of your TensorRT-LLM server you can leverage the built-in `benchmark_serving.py` script. To do this first creating a wrapper `bench.sh` script. + +```shell +cat <<EOF > bench.sh +concurrency_list="32 64 128 256 512 1024 2048 4096" +multi_round=5 +isl=1024 +osl=1024 +result_dir=/tmp/deepseek_r1_output + +for concurrency in ${concurrency_list}; do + num_prompts=$((concurrency * multi_round)) + python -m tensorrt_llm.serve.scripts.benchmark_serving \ + --model deepseek-ai/DeepSeek-R1-0528 \ + --backend openai \ + --dataset-name "random" \ + --random-input-len ${isl} \ + --random-output-len ${osl} \ + --random-prefix-len 0 \ + --random-ids \ + --num-prompts ${num_prompts} \ + --max-concurrency ${concurrency} \ + --ignore-eos \ + --tokenize-on-client \ + --percentile-metrics "ttft,tpot,itl,e2el" +done +EOF +chmod +x bench.sh +``` + +To benchmark the FP4 model, replace `--model deepseek-ai/DeepSeek-R1-0528` with `--model nvidia/DeepSeek-R1-0528-FP4`. + +If you want to save the results to a file add the following options. + +```shell +--save-result \ +--result-dir "${result_dir}" \ +--result-filename "concurrency_${concurrency}.json" +``` + +For more benchmarking options see <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt\_llm/serve/scripts/benchmark\_serving.py>. + +Run `bench.sh` to begin a serving benchmark. This will take a long time if you run all the concurrencies mentioned in the above `bench.sh` script. + +```shell +./bench.sh +``` + +Sample TensorRT-LLM serving benchmark output. Your results may vary due to ongoing software optimizations. + +``` +============ Serving Benchmark Result ============ +Successful requests: 16 +Benchmark duration (s): 17.66 +Total input tokens: 16384 +Total generated tokens: 16384 +Request throughput (req/s): [result] +Output token throughput (tok/s): [result] +Total Token throughput (tok/s): [result] +User throughput (tok/s): [result] +---------------Time to First Token---------------- +Mean TTFT (ms): [result] +Median TTFT (ms): [result] +P99 TTFT (ms): [result] +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): [result] +Median TPOT (ms): [result] +P99 TPOT (ms): [result] +---------------Inter-token Latency---------------- +Mean ITL (ms): [result] +Median ITL (ms): [result] +P99 ITL (ms): [result] +----------------End-to-end Latency---------------- +Mean E2EL (ms): [result] +Median E2EL (ms): [result] +P99 E2EL (ms): [result] +================================================== +``` + +### Key Metrics + +* Median Time to First Token (TTFT) + * The typical time elapsed from when a request is sent until the first output token is generated. +* Median Time Per Output Token (TPOT) + * The typical time required to generate each token *after* the first one. +* Median Inter-Token Latency (ITL) + * The typical time delay between the completion of one token and the completion of the next. +* Median End-to-End Latency (E2EL) + * The typical total time from when a request is submitted until the final token of the response is received. +* Total Token Throughput + * The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens. diff --git a/docs/source/deployment-guide/quick-start-recipe-for-llama3.3-70b-on-trtllm.md b/docs/source/deployment-guide/quick-start-recipe-for-llama3.3-70b-on-trtllm.md new file mode 100644 index 0000000000..31168002cb --- /dev/null +++ b/docs/source/deployment-guide/quick-start-recipe-for-llama3.3-70b-on-trtllm.md @@ -0,0 +1,362 @@ +# Quick Start Recipe for Llama3.3 70B on TensorRT-LLM - Blackwell & Hopper Hardware + +## Introduction + +This deployment guide provides step-by-step instructions for running the Llama 3.3-70B Instruct model using TensorRT-LLM with FP8 and NVFP4 quantization, optimized for NVIDIA GPUs. It covers the complete setup required; from accessing model weights and preparing the software environment to configuring TensorRT-LLM parameters, launching the server, and validating inference output. + +The guide is intended for developers and practitioners seeking high-throughput or low-latency inference using NVIDIA’s accelerated stack—starting with the PyTorch container from NGC, then installing TensorRT-LLM for model serving, FlashInfer for optimized CUDA kernels, and ModelOpt to enable FP8 and NVFP4 quantized execution. + +## Access & Licensing + +To use Llama 3.3-70B, you must first agree to Meta’s Llama 3 Community License ([https://ai.meta.com/resources/models-and-libraries/llama-downloads/](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)). NVIDIA’s quantized versions (FP8 and FP4) are built on top of the base model and are available for research and commercial use under the same license. + +## Prerequisites + +GPU: NVIDIA Blackwell or Hopper Architecture +OS: Linux +Drivers: CUDA Driver 575 or Later +Docker with NVIDIA Container Toolkit installed +Python3 and python3-pip (Optional, for accuracy evaluation only) + +## Models + +* FP8 model: [Llama-3.3-70B-Instruct-FP8](https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP8) +* NVFP4 model: [Llama-3.3-70B-Instruct-FP4](https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP4) + + +Note that NVFP4 is only supported on NVIDIA Blackwell + +## Deployment Steps + +### Run Docker Container + +Run the docker container using the TensorRT-LLM NVIDIA NGC image. + +```shell +docker run --rm -it \ +--ipc=host \ +--gpus all \ +-p 8000:8000 \ +-v ~/.cache:/root/.cache:rw \ +--name tensorrt_llm \ +nvcr.io/nvidia/tensorrt-llm/release:1.0.0rc6 \ +/bin/bash +``` + +Note: + +* The command mounts your user `.cache` directory to save the downloaded model checkpoints which are saved to `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. If the `~/.cache` directory doesn’t exist please create it using `$ mkdir ~/.cache`. +* You can mount additional directories and paths using the `-v <host_path>:<container_path>` flag if needed, such as mounting the downloaded weight paths. +* The command also maps port `8000` from the container to your host so you can access the LLM API endpoint from your host +* See the <https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags> for all the available containers. The containers published in the main branch weekly have `rcN` suffix, while the monthly release with QA tests has no `rcN` suffix. Use the `rc` release to get the latest model and feature support. + +If you want to use latest main branch, you can choose to build from source to install TensorRT-LLM, the steps refer to <https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html>. + +### Creating the TRT-LLM Server config + +We create a YAML configuration file `/tmp/config.yml` for the TensorRT-LLM Server and populate it with the following recommended performance settings. + +```shell +EXTRA_LLM_API_FILE=/tmp/config.yml + +cat << EOF > ${EXTRA_LLM_API_FILE} +enable_attention_dp: false +cuda_graph_config: + enable_padding: true + max_batch_size: 1024 +kv_cache_config: + dtype: fp8 +EOF +``` + +### Launch the TRT-LLM Server + +Below is an example command to launch the TRT-LLM server with the Llama-3.3-70B-Instruct-FP8 model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section. + +```shell +trtllm-serve nvidia/Llama-3.3-70B-Instruct-FP8 \ + --host 0.0.0.0 \ + --port 8000 \ + --backend pytorch \ + --max_batch_size 1024 \ + --max_num_tokens 2048 \ + --max_seq_len 2048 \ + --kv_cache_free_gpu_memory_fraction 0.9 \ + --tp_size 1 \ + --ep_size 1 \ + --trust_remote_code \ + --extra_llm_api_options ${EXTRA_LLM_API_FILE} +``` + +After the server is set up, the client can now send prompt requests to the server and receive results. + +### Configs and Parameters + +These options are used directly on the command line when you start the `trtllm-serve` process. + +#### `--tp_size` + +* **Description:** Sets the **tensor-parallel size**. This should typically match the number of GPUs you intend to use for a single model instance. + +#### `--ep_size` + +* **Description:** Sets the **expert-parallel size** for Mixture-of-Experts (MoE) models. Like `tp_size`, this should generally match the number of GPUs you're using. This setting has no effect on non-MoE models. + +#### `--kv_cache_free_gpu_memory_fraction` + +* **Description:** A value between `0.0` and `1.0` that specifies the fraction of free GPU memory to reserve for the KV cache after the model is loaded. Since memory usage can fluctuate, this buffer helps prevent out-of-memory (OOM) errors. +* **Recommendation:** If you experience OOM errors, try reducing this value to `0.8` or lower. + +#### `--backend pytorch` + +* **Description:** Tells TensorRT-LLM to use the **pytorch** backend. + +#### `--max_batch_size` + +* **Description:** The maximum number of user requests that can be grouped into a single batch for processing. + +#### `--max_num_tokens` + +* **Description:** The maximum total number of tokens (across all requests) allowed inside a single scheduled batch. + +#### `--max_seq_len` + +* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens. + +#### `--trust_remote_code` + +* **Description:** Allows TensorRT-LLM to download models and tokenizers from Hugging Face. This flag is passed directly to the Hugging Face API. + + +#### Extra LLM API Options (YAML Configuration) + +These options provide finer control over performance and are set within a YAML file passed to the `trtllm-serve` command via the `--extra_llm_api_options` argument. + +#### `kv_cache_config` + +* **Description**: A section for configuring the Key-Value (KV) cache. + +* **Options**: + + * `dtype`: Sets the data type for the KV cache. + **Default**: `"auto"` (uses the data type specified in the model checkpoint). + +#### `cuda_graph_config` + +* **Description**: A section for configuring CUDA graphs to optimize performance. + +* **Options**: + + * `enable_padding`: If `"true"`, input batches are padded to the nearest `cuda_graph_batch_size`. This can significantly improve performance. + + **Default**: `false` + + * `max_batch_size`: Sets the maximum batch size for which a CUDA graph will be created. + + **Default**: `0` + + **Recommendation**: Set this to the same value as the `--max_batch_size` command-line option. + + * `batch_sizes`: A specific list of batch sizes to create CUDA graphs for. + + **Default**: `None` + +#### `moe_config` + +* **Description**: Configuration for Mixture-of-Experts (MoE) models. + +* **Options**: + + * `backend`: The backend to use for MoE operations. + **Default**: `CUTLASS` + +#### `attention_backend` + +* **Description**: The backend to use for attention calculations. + +* **Default**: `TRTLLM` + +See the [`TorchLlmArgs` class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options which can be used in the `extra_llm_api_options`. + +## Testing API Endpoint + +### Basic Test + +Start a new terminal on the host to test the TensorRT-LLM server you just launched. + +You can query the health/readiness of the server using: + +```shell +curl -s -o /dev/null -w "Status: %{http_code}\n" "http://localhost:8000/health" +``` + +When the `Status: 200` code is returned, the server is ready for queries. Note that the very first query may take longer due to initialization and compilation. + +After the TRT-LLM server is set up and shows Application startup complete, you can send requests to the server. + +```shell +curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ + "model": "nvidia/Llama-3.3-70B-Instruct-FP8", + "prompt": "Where is New York?", + "max_tokens": 16, + "temperature": 0 +}' +``` + +Here is an example response, showing that the TRT-LLM server returns “New York is a state located in the northeastern United States. It is bordered by”, completing the input sequence. + +```json +{"id":"cmpl-bc1393d529ce485c961d9ffee5b25d72","object":"text_completion","created":1753843963,"model":"nvidia/Llama-3.3-70B-Instruct-FP8","choices":[{"index":0,"text":" New York is a state located in the northeastern United States. It is bordered by","token_ids":null,"logprobs":null,"context_logits":null,"finish_reason":"length","stop_reason":null,"disaggregated_params":null}],"usage":{"prompt_tokens":6,"total_tokens":22,"completion_tokens":16},"prompt_token_ids":null} +``` + +### Troubleshooting Tips + +* If you encounter CUDA out-of-memory errors, try reducing `max_batch_size` or `max_seq_len`. +* Ensure your model checkpoints are compatible with the expected format. +* For performance issues, check GPU utilization with nvidia-smi while the server is running. +* If the container fails to start, verify that the NVIDIA Container Toolkit is properly installed. +* For connection issues, make sure the server port (`8000` in this guide) is not being used by another application. + +### Running Evaluations to Verify Accuracy (Optional) + +We use the lm-eval tool to test the model’s accuracy. For more information see <https://github.com/EleutherAI/lm-evaluation-harness>. + +To run the evaluation harness exec into the running TensorRT-LLM container and install with this command: + +```shell +docker exec -it tensorrt_llm /bin/bash + +pip install lm_eval +``` + +FP8 command for GSM8K + +* Note: The tokenizer will add BOS (beginning of sentence token) before input prompt by default which leads to accuracy regression on GSM8K task for Llama 3.3 70B instruction model. So, set `add_special_tokens=False` to avoid it. + +```shell +MODEL_PATH=nvidia/Llama-3.3-70B-Instruct-FP8 + +lm_eval --model local-completions --tasks gsm8k --batch_size 256 --gen_kwargs temperature=0.0,add_special_tokens=False --num_fewshot 5 --model_args model=${MODEL_PATH},base_url=http://localhost:8000/v1/completions,num_concurrent=32,max_retries=20,tokenized_requests=False --log_samples --output_path trtllm.fp8.gsm8k +``` + +Sample result in Blackwell. + +``` +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9348|± |0.0068| +| | |strict-match | 5|exact_match|↑ |0.8870|± |0.0087| +``` + +FP4 command for GSM8K + +* Note: The tokenizer will add BOS before input prompt by default, which leads to accuracy regression on GSM8K task for LLama 3.3 70B instruction model. So set `add_special_tokens=False` to avoid it. + +```shell +MODEL_PATH=nvidia/Llama-3.3-70B-Instruct-FP4 + +lm_eval --model local-completions --tasks gsm8k --batch_size 256 --gen_kwargs temperature=0.0,add_special_tokens=False --num_fewshot 5 --model_args model=${MODEL_PATH},base_url=http://localhost:8000/v1/completions,num_concurrent=32,max_retries=20,tokenized_requests=False --log_samples --output_path trtllm.fp4.gsm8k +``` + +Sample result in Blackwell + +``` +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9356|± |0.0068| +| | |strict-match | 5|exact_match|↑ |0.8393|± |0.0101| +``` + +## Benchmarking Performance + +To benchmark the performance of your TensorRT-LLM server you can leverage the built-in `benchmark_serving.py` script. To do this first creating a wrapper `bench.sh` script. + +```shell +cat <<EOF > bench.sh +concurrency_list="1 2 4 8 16 32 64 128 256" +multi_round=5 +isl=1024 +osl=1024 +result_dir=/tmp/llama3.3_output + +for concurrency in ${concurrency_list}; do + num_prompts=$((concurrency * multi_round)) + python -m tensorrt_llm.serve.scripts.benchmark_serving \ + --model nvidia/Llama-3.3-70B-Instruct-FP8 \ + --backend openai \ + --dataset-name "random" \ + --random-input-len ${isl} \ + --random-output-len ${osl} \ + --random-prefix-len 0 \ + --random-ids \ + --num-prompts ${num_prompts} \ + --max-concurrency ${concurrency} \ + --ignore-eos \ + --tokenize-on-client \ + --percentile-metrics "ttft,tpot,itl,e2el" +done +EOF +chmod +x bench.sh +``` + +To benchmark the FP4 model, replace `--model nvidia/Llama-3.3-70B-Instruct-FP8` with `--model nvidia/Llama-3.3-70B-Instruct-FP4`. + +If you want to save the results to a file add the following options. + +```shell +--save-result \ +--result-dir "${result_dir}" \ +--result-filename "concurrency_${concurrency}.json" +``` + +For more benchmarking options see <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt\_llm/serve/scripts/benchmark\_serving.py>. + +Run `bench.sh` to begin a serving benchmark. This will take a long time if you run all the concurrencies mentioned in the above `bench.sh` script. + +```shell +./bench.sh +``` + +Sample TensorRT-LLM serving benchmark output. Your results may vary due to ongoing software optimizations. + +``` +============ Serving Benchmark Result ============ +Successful requests: 16 +Benchmark duration (s): 17.66 +Total input tokens: 16384 +Total generated tokens: 16384 +Request throughput (req/s): [result] +Output token throughput (tok/s): [result] +Total Token throughput (tok/s): [result] +User throughput (tok/s): [result] +---------------Time to First Token---------------- +Mean TTFT (ms): [result] +Median TTFT (ms): [result] +P99 TTFT (ms): [result] +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): [result] +Median TPOT (ms): [result] +P99 TPOT (ms): [result] +---------------Inter-token Latency---------------- +Mean ITL (ms): [result] +Median ITL (ms): [result] +P99 ITL (ms): [result] +----------------End-to-end Latency---------------- +Mean E2EL (ms): [result] +Median E2EL (ms): [result] +P99 E2EL (ms): [result] +================================================== +``` + +### Key Metrics + +* Median Time to First Token (TTFT) + * The typical time elapsed from when a request is sent until the first output token is generated. +* Median Time Per Output Token (TPOT) + * The typical time required to generate each token *after* the first one. +* Median Inter-Token Latency (ITL) + * The typical time delay between the completion of one token and the completion of the next. +* Median End-to-End Latency (E2EL) + * The typical total time from when a request is submitted until the final token of the response is received. +* Total Token Throughput + * The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens. diff --git a/docs/source/deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md b/docs/source/deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md new file mode 100644 index 0000000000..6ec972dc45 --- /dev/null +++ b/docs/source/deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md @@ -0,0 +1,357 @@ +# Quick Start Recipe for Llama4 Scout 17B on TensorRT-LLM - Blackwell & Hopper Hardware + +## Introduction + +This deployment guide provides step-by-step instructions for running the Llama-4-Scout-17B-16E-Instruct model using TensorRT-LLM with FP8 and NVFP4 quantization, optimized for NVIDIA GPUs. It covers the complete setup required; from accessing model weights and preparing the software environment to configuring TensorRT-LLM parameters, launching the server, and validating inference output. + +The guide is intended for developers and practitioners seeking high-throughput or low-latency inference using NVIDIA’s accelerated stack—starting with the PyTorch container from NGC, then installing TensorRT-LLM for model serving, FlashInfer for optimized CUDA kernels, and ModelOpt to enable FP8 and NVFP4 quantized execution. + +## Access & Licensing + +To use Llama4 Scout 17B, you must first agree to Meta’s [Llama 4 Community License](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE). NVIDIA’s quantized versions (FP8 and NVFP4) are built on top of the base model and are available for research and commercial use under the same license. + +## Prerequisites + +* GPU: NVIDIA Blackwell or Hopper Architecture +* OS: Linux +* Drivers: CUDA Driver 575 or Later +* Docker with NVIDIA Container Toolkit installed +* Python3 and python3-pip (Optional, for accuracy evaluation only) + +## Models + +* FP8 model: [Llama-4-Scout-17B-16E-Instruct-FP8](https://huggingface.co/nvidia/Llama-4-Scout-17B-16E-Instruct-FP8) +* NVFP4 model: [Llama-4-Scout-17B-16E-Instruct-FP4](https://huggingface.co/nvidia/Llama-4-Scout-17B-16E-Instruct-FP4) + +Note that NVFP4 is only supported on NVIDIA Blackwell platform. + +## Deployment Steps + +### Run Docker Container + +Run the docker container using the TensorRT-LLM NVIDIA NGC image. + +```shell +docker run --rm -it \ +--ipc=host \ +--gpus all \ +-p 8000:8000 \ +-v ~/.cache:/root/.cache:rw \ +--name tensorrt_llm \ +nvcr.io/nvidia/tensorrt-llm/release:1.0.0rc6 \ +/bin/bash +``` + +Note: + +* The command mounts your user `.cache` directory to save the downloaded model checkpoints which are saved to `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. If the `~/.cache` directory doesn’t exist please create it using `$ mkdir ~/.cache`. +* You can mount additional directories and paths using the `-v <host_path>:<container_path>` flag if needed, such as mounting the downloaded weight paths. +* The command also maps port `8000` from the container to your host so you can access the LLM API endpoint from your host +* See the <https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags> for all the available containers. The containers published in the main branch weekly have `rcN` suffix, while the monthly release with QA tests has no `rcN` suffix. Use the `rc` release to get the latest model and feature support. + +If you want to use latest main branch, you can choose to build from source to install TensorRT-LLM, the steps refer to <https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html>. + +### Creating the TRT-LLM Server config + +We create a YAML configuration file `/tmp/config.yml` for the TensorRT-LLM Server and populate it with the following recommended performance settings. + +```shell +EXTRA_LLM_API_FILE=/tmp/config.yml + +cat << EOF > ${EXTRA_LLM_API_FILE} +enable_attention_dp: false +cuda_graph_config: + enable_padding: true + max_batch_size: 1024 +kv_cache_config: + dtype: fp8 +EOF +``` + +### Launch the TRT-LLM Server + +Below is an example command to launch the TRT-LLM server with the Llama-4-Scout-17B-16E-Instruct-FP8 model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section. + +```shell +trtllm-serve nvidia/Llama-4-Scout-17B-16E-Instruct-FP8 \ + --host 0.0.0.0 \ + --port 8000 \ + --backend pytorch \ + --max_batch_size 1024 \ + --max_num_tokens 2048 \ + --max_seq_len 2048 \ + --kv_cache_free_gpu_memory_fraction 0.9 \ + --tp_size 1 \ + --ep_size 1 \ + --trust_remote_code \ + --extra_llm_api_options ${EXTRA_LLM_API_FILE} +``` + +After the server is set up, the client can now send prompt requests to the server and receive results. + +### Configs and Parameters + +These options are used directly on the command line when you start the `trtllm-serve` process. + +#### `--tp_size` + +* **Description:** Sets the **tensor-parallel size**. This should typically match the number of GPUs you intend to use for a single model instance. + +#### `--ep_size` + +* **Description:** Sets the **expert-parallel size** for Mixture-of-Experts (MoE) models. Like `tp_size`, this should generally match the number of GPUs you're using. This setting has no effect on non-MoE models. + +#### `--kv_cache_free_gpu_memory_fraction` + +* **Description:** A value between `0.0` and `1.0` that specifies the fraction of free GPU memory to reserve for the KV cache after the model is loaded. Since memory usage can fluctuate, this buffer helps prevent out-of-memory (OOM) errors. +* **Recommendation:** If you experience OOM errors, try reducing this value to `0.7` or lower. + +#### `--backend pytorch` + +* **Description:** Tells TensorRT-LLM to use the **pytorch** backend. + +#### `--max_batch_size` + +* **Description:** The maximum number of user requests that can be grouped into a single batch for processing. + +#### `--max_num_tokens` + +* **Description:** The maximum total number of tokens (across all requests) allowed inside a single scheduled batch. + +#### `--max_seq_len` + +* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens. + +#### `--trust_remote_code` + +* **Description:** Allows TensorRT-LLM to download models and tokenizers from Hugging Face. This flag is passed directly to the Hugging Face API. + + +#### Extra LLM API Options (YAML Configuration) + +These options provide finer control over performance and are set within a YAML file passed to the `trtllm-serve` command via the `--extra_llm_api_options` argument. + +#### `kv_cache_config` + +* **Description**: A section for configuring the Key-Value (KV) cache. + +* **Options**: + + * `dtype`: Sets the data type for the KV cache. + **Default**: `"auto"` (uses the data type specified in the model checkpoint). + +#### `cuda_graph_config` + +* **Description**: A section for configuring CUDA graphs to optimize performance. + +* **Options**: + + * `enable_padding`: If `"true"`, input batches are padded to the nearest `cuda_graph_batch_size`. This can significantly improve performance. + + **Default**: `false` + + * `max_batch_size`: Sets the maximum batch size for which a CUDA graph will be created. + + **Default**: `0` + + **Recommendation**: Set this to the same value as the `--max_batch_size` command-line option. + + * `batch_sizes`: A specific list of batch sizes to create CUDA graphs for. + + **Default**: `None` + +#### `moe_config` + +* **Description**: Configuration for Mixture-of-Experts (MoE) models. + +* **Options**: + + * `backend`: The backend to use for MoE operations. + **Default**: `CUTLASS` + +#### `attention_backend` + +* **Description**: The backend to use for attention calculations. + +* **Default**: `TRTLLM` + +See the [`TorchLlmArgs` class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options which can be used in the `extra_llm_api_options`. + +## Testing API Endpoint + +### Basic Test + +Start a new terminal on the host to test the TensorRT-LLM server you just launched. + +You can query the health/readiness of the server using: + +```shell +curl -s -o /dev/null -w "Status: %{http_code}\n" "http://localhost:8000/health" +``` + +When the `Status: 200` code is returned, the server is ready for queries. Note that the very first query may take longer due to initialization and compilation. + +After the TRT-LLM server is set up and shows Application startup complete, you can send requests to the server. + +```shell +curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ + "model": "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + "prompt": "Where is New York?", + "max_tokens": 16, + "temperature": 0 +}' +``` + +Here is an example response, showing that the TRT-LLM server returns “New York is a state located in the northeastern United States. It is bordered by”, completing the input sequence. + +```json +{"id":"cmpl-bc1393d529ce485c961d9ffee5b25d72","object":"text_completion","created":1753843963,"model":"$MODEL","choices":[{"index":0,"text":" New York is a state located in the northeastern United States. It is bordered by","token_ids":null,"logprobs":null,"context_logits":null,"finish_reason":"length","stop_reason":null,"disaggregated_params":null}],"usage":{"prompt_tokens":6,"total_tokens":22,"completion_tokens":16},"prompt_token_ids":null} +``` + +### Troubleshooting Tips + +* If you encounter CUDA out-of-memory errors, try reducing `max_batch_size` or `max_seq_len`. +* Ensure your model checkpoints are compatible with the expected format. +* For performance issues, check GPU utilization with nvidia-smi while the server is running. +* If the container fails to start, verify that the NVIDIA Container Toolkit is properly installed. +* For connection issues, make sure the server port (`8000` in this guide) is not being used by another application. + +### Running Evaluations to Verify Accuracy (Optional) + +We use the lm-eval tool to test the model’s accuracy. For more information see <https://github.com/EleutherAI/lm-evaluation-harness>. + +To run the evaluation harness exec into the running TensorRT-LLM container and install with this command: + +```shell +docker exec -it tensorrt_llm /bin/bash + +pip install lm_eval +``` + +FP8 command for GSM8K + +```shell +MODEL_PATH=nvidia/Llama-4-Scout-17B-16E-Instruct-FP8 + +lm_eval --model local-completions --tasks gsm8k --batch_size 256 --gen_kwargs temperature=0.0 --num_fewshot 5 --model_args model=${MODEL_PATH},base_url=http://localhost:8000/v1/completions,num_concurrent=32,max_retries=20,tokenized_requests=False --log_samples --output_path trtllm.fp8.gsm8k +``` + +Sample result in Blackwell. + +``` +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9189|± |0.0075| +| | |strict-match | 5|exact_match|↑ |0.8984|± |0.0083| +``` + +FP4 command for GSM8K + +```shell +MODEL_PATH=nvidia/Llama-4-Scout-17B-16E-Instruct-FP4 + +lm_eval --model local-completions --tasks gsm8k --batch_size 256 --gen_kwargs temperature=0.0 --num_fewshot 5 --model_args model=${MODEL_PATH},base_url=http://localhost:8000/v1/completions,num_concurrent=32,max_retries=20,tokenized_requests=False --log_samples --output_path trtllm.fp4.gsm8k +``` + +Sample result in Blackwell + +``` +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9075|± |0.0080| +| | |strict-match | 5|exact_match|↑ |0.8908|± |0.0086| +``` + +## Benchmarking Performance + +To benchmark the performance of your TensorRT-LLM server you can leverage the built-in `benchmark_serving.py` script. To do this first creating a wrapper `bench.sh` script. + +```shell +cat <<EOF > bench.sh +concurrency_list="1 2 4 8 16 32 64 128 256" +multi_round=5 +isl=1024 +osl=1024 +result_dir=/tmp/llama4_output + +for concurrency in ${concurrency_list}; do + num_prompts=$((concurrency * multi_round)) + python -m tensorrt_llm.serve.scripts.benchmark_serving \ + --model nvidia/Llama-4-Scout-17B-16E-Instruct-FP8 \ + --backend openai \ + --dataset-name "random" \ + --random-input-len ${isl} \ + --random-output-len ${osl} \ + --random-prefix-len 0 \ + --random-ids \ + --num-prompts ${num_prompts} \ + --max-concurrency ${concurrency} \ + --ignore-eos \ + --tokenize-on-client \ + --percentile-metrics "ttft,tpot,itl,e2el" +done +EOF +chmod +x bench.sh +``` + +To benchmark the FP4 model, replace `--model nvidia/Llama-4-Scout-17B-16E-Instruct-FP8` with `--model nvidia/Llama-4-Scout-17B-16E-Instruct-FP4`. + +If you want to save the results to a file add the following options. + +```shell +--save-result \ +--result-dir "${result_dir}" \ +--result-filename "concurrency_${concurrency}.json" +``` + +For more benchmarking options see <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt\_llm/serve/scripts/benchmark\_serving.py>. + +Run bench.sh to begin a serving benchmark. This will take a long time if you run all the concurrencies mentioned in the above bench.sh script. + +```shell +./bench.sh +``` + +Sample TensorRT-LLM serving benchmark output. Your results may vary due to ongoing software optimizations. + +``` +============ Serving Benchmark Result ============ +Successful requests: 16 +Benchmark duration (s): 17.66 +Total input tokens: 16384 +Total generated tokens: 16384 +Request throughput (req/s): [result] +Output token throughput (tok/s): [result] +Total Token throughput (tok/s): [result] +User throughput (tok/s): [result] +---------------Time to First Token---------------- +Mean TTFT (ms): [result] +Median TTFT (ms): [result] +P99 TTFT (ms): [result] +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): [result] +Median TPOT (ms): [result] +P99 TPOT (ms): [result] +---------------Inter-token Latency---------------- +Mean ITL (ms): [result] +Median ITL (ms): [result] +P99 ITL (ms): [result] +----------------End-to-end Latency---------------- +Mean E2EL (ms): [result] +Median E2EL (ms): [result] +P99 E2EL (ms): [result] +================================================== +``` + +### Key Metrics + +* Median Time to First Token (TTFT) + * The typical time elapsed from when a request is sent until the first output token is generated. +* Median Time Per Output Token (TPOT) + * The typical time required to generate each token *after* the first one. +* Median Inter-Token Latency (ITL) + * The typical time delay between the completion of one token and the completion of the next. +* Median End-to-End Latency (E2EL) + * The typical total time from when a request is submitted until the final token of the response is received. +* Total Token Throughput + * The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens. diff --git a/docs/source/index.rst b/docs/source/index.rst index cb04be7025..b0964ca287 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,16 @@ Welcome to TensorRT-LLM's Documentation! installation/build-from-source-linux.md +.. toctree:: + :maxdepth: 2 + :caption: Deployment Guide + :name: Deployment Guide + + deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md + deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md + deployment-guide/quick-start-recipe-for-llama3.3-70b-on-trtllm.md + + .. toctree:: :maxdepth: 2 :caption: LLM API diff --git a/docs/source/installation/linux.md b/docs/source/installation/linux.md index ab471e8c1d..9262453b66 100644 --- a/docs/source/installation/linux.md +++ b/docs/source/installation/linux.md @@ -16,11 +16,6 @@ # Optional step: Only required for NVIDIA Blackwell GPUs and SBSA platform pip3 install torch==2.7.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 - # Optional step: Workaround for deep_gemm installation failure on SBSA platform - # The actual deep_gemm package and version should be obtained from the requirements.txt file. - pip3 install 'deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a' \ - --extra-index-url https://download.pytorch.org/whl/cu128 - sudo apt-get -y install libopenmpi-dev ``` diff --git a/docs/source/media/ad_overview.png b/docs/source/media/ad_overview.png new file mode 100644 index 0000000000..333804297a Binary files /dev/null and b/docs/source/media/ad_overview.png differ diff --git a/docs/source/performance/perf-benchmarking.md b/docs/source/performance/perf-benchmarking.md index 814e27b3d3..55caef07ba 100644 --- a/docs/source/performance/perf-benchmarking.md +++ b/docs/source/performance/perf-benchmarking.md @@ -79,7 +79,7 @@ that have been validated extensively and is the same listing as seen on the - [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) - [meta-llama/Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) - [meta-llama/Llama-3.1-405B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct) -- [mistralai/Mixtral-8x7B-v0.1-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1-Instruct) +- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) ```{tip} `trtllm-bench` can automatically download the model from Hugging Face Model Hub. @@ -236,15 +236,6 @@ The following command builds an FP8 quantized engine by specifying the engine tu trtllm-bench --model meta-llama/Llama-3.1-8B build --quantization FP8 --max_seq_len 4096 --max_batch_size 1024 --max_num_tokens 2048 ``` -- [Experimental] Build engine with target ISL/OSL for optimization: -In this experimental mode, you can provide hints to `trtllm-bench`'s tuning heuristic to optimize the engine on specific ISL and OSL targets. -Generally, the target ISL and OSL aligns with the average ISL and OSL of the dataset, but you can experiment with different values to optimize the engine using this mode. -The following command builds an FP8 quantized engine and optimizes for ISL:OSL targets of 128:128. - -```shell -trtllm-bench --model meta-llama/Llama-3.1-8B build --quantization FP8 --max_seq_len 4096 --target_isl 128 --target_osl 128 -``` - #### Parallelism Mapping Support The `trtllm-bench build` subcommand supports combinations of tensor-parallel (TP) and pipeline-parallel (PP) mappings as long as the world size (`tp_size x pp_size`) `<=` `8`. The parallelism mapping in build subcommad is controlled by `--tp_size` and `--pp_size` options. The following command builds an engine with TP2-PP2 mapping. diff --git a/docs/source/reference/multimodal-feature-support-matrix.md b/docs/source/reference/multimodal-feature-support-matrix.md new file mode 100644 index 0000000000..bb5175c9da --- /dev/null +++ b/docs/source/reference/multimodal-feature-support-matrix.md @@ -0,0 +1,13 @@ +# Multimodal Feature Support Matrix (PyTorch Backend) + +| Model | CUDA Graph | Encoder IFB | KV Cache Reuse | Chunked Prefill | +| :----------------- | :--------- | :------------------ | :------------- | :-------------- | +| Gemma 3 | Yes | Yes | No | No | +| HyperCLOVA | Yes | Yes | No | No | +| VILA | Yes | No | No | No | +| LLaVA-NeXT | Yes | Yes | No | No | +| Llama 4 | Yes | No | No | No | +| Mistral-Small-3.1 | Yes | Yes | No | No | +| Phi-4-multimodal | Yes | Yes | No | No | +| Qwen2-VL | Yes | Yes | Yes | No | +| Qwen2.5-VL | Yes | Yes | Yes | No | diff --git a/docs/source/reference/precision.md b/docs/source/reference/precision.md index 2d30c9053a..b31eff6d62 100644 --- a/docs/source/reference/precision.md +++ b/docs/source/reference/precision.md @@ -103,8 +103,7 @@ Python function, for details. This release includes examples of applying GPTQ to [GPT-NeoX](source:examples/models/core/gpt) and [LLaMA-v2](source:examples/models/core/llama), as well as an example of using AWQ with -[GPT-J](source:examples/models/contrib/gpt). Those examples are experimental implementations and -are likely to evolve in a future release. +[GPT-J](source:examples/models/contrib/gptj). ## FP8 (Hopper) diff --git a/docs/source/release-notes.md b/docs/source/release-notes.md index d0cf99c69e..5ce3e71325 100644 --- a/docs/source/release-notes.md +++ b/docs/source/release-notes.md @@ -640,7 +640,7 @@ All published functionality in the Release Notes has been fully tested and verif ### Known Issues -- On Windows, installation of TensorRT-LLM may succeed, but you might hit `OSError: exception: access violation reading 0x0000000000000000` when importing the library in Python. See [Installing on Windows](https://nvidia.github.io/TensorRT-LLM/installation/windows.html) for workarounds. +- On Windows, installation of TensorRT-LLM may succeed, but you might hit `OSError: exception: access violation reading 0x0000000000000000` when importing the library in Python. ## TensorRT-LLM Release 0.11.0 @@ -1046,7 +1046,7 @@ Refer to the {ref}`support-matrix-software` section for a list of supported mode - System prompt caching - Enabled split-k for weight-only cutlass kernels - FP8 KV cache support for XQA kernel -- New Python builder API and `trtllm-build` command (already applied to [blip2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/blip2) and [OPT](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/opt#3-build-tensorrt-engines)) +- Added Python builder API, `trtllm-build` command, and OPT support - Support `StoppingCriteria` and `LogitsProcessor` in Python generate API - FHMA support for chunked attention and paged KV cache - Performance enhancements include: diff --git a/docs/source/torch.md b/docs/source/torch.md index b04c98db1d..31684cd657 100644 --- a/docs/source/torch.md +++ b/docs/source/torch.md @@ -2,10 +2,9 @@ ```{note} Note: -This feature is currently experimental, and the related API is subjected to change in future versions. +This feature is currently in beta, and the related API is subjected to change in future versions. ``` - -To enhance the usability of the system and improve developer efficiency, TensorRT-LLM launches a new experimental backend based on PyTorch. +To enhance the usability of the system and improve developer efficiency, TensorRT-LLM launches a new backend based on PyTorch. The PyTorch backend of TensorRT-LLM is available in version 0.17 and later. You can try it via importing `tensorrt_llm._torch`. @@ -29,7 +28,6 @@ Here is a simple example to show how to use `tensorrt_llm.LLM` API with Llama mo - [Architecture Overview](./torch/arch_overview.md) - [Adding a New Model](./torch/adding_new_model.md) -- [Examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/pytorch/README.md) ## Key Components @@ -40,3 +38,7 @@ Here is a simple example to show how to use `tensorrt_llm.LLM` API with Llama mo ## Known Issues - The PyTorch backend on SBSA is incompatible with bare metal environments like Ubuntu 24.04. Please use the [PyTorch NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for optimal support on SBSA platforms. + +## Prototype Features + +- [AutoDeploy: Seamless Model Deployment from PyTorch to TensorRT-LLM](./torch/auto_deploy/auto-deploy.md) diff --git a/docs/source/torch/auto_deploy/advanced/benchmarking_with_trtllm_bench.md b/docs/source/torch/auto_deploy/advanced/benchmarking_with_trtllm_bench.md new file mode 100644 index 0000000000..6032aacd4f --- /dev/null +++ b/docs/source/torch/auto_deploy/advanced/benchmarking_with_trtllm_bench.md @@ -0,0 +1,93 @@ +# Benchmarking with trtllm-bench + +AutoDeploy is integrated with the `trtllm-bench` performance benchmarking utility, enabling you to measure comprehensive performance metrics such as token throughput, request throughput, and latency for your AutoDeploy-optimized models. + +## Getting Started + +Before benchmarking with AutoDeploy, review the [TensorRT-LLM benchmarking guide](../../../performance/perf-benchmarking.md#running-with-the-pytorch-workflow) to familiarize yourself with the standard trtllm-bench workflow and best practices. + +## Basic Usage + +Invoke the AutoDeploy backend by specifying `--backend _autodeploy` in your `trtllm-bench` command: + +```bash +trtllm-bench \ + --model meta-llama/Llama-3.1-8B \ + throughput \ + --dataset /tmp/synthetic_128_128.txt \ + --backend _autodeploy +``` + +```{note} +As in the PyTorch workflow, AutoDeploy does not require a separate `trtllm-bench build` step. The model is automatically optimized during benchmark initialization. +``` + +## Advanced Configuration + +For more granular control over AutoDeploy's behavior during benchmarking, use the `--extra_llm_api_options` flag with a YAML configuration file: + +```bash +trtllm-bench \ + --model meta-llama/Llama-3.1-8B \ + throughput \ + --dataset /tmp/synthetic_128_128.txt \ + --backend _autodeploy \ + --extra_llm_api_options autodeploy_config.yaml +``` + +### Configuration Examples + +#### Basic Performance Configuration (`autodeploy_config.yaml`) + +```yaml +# Compilation backend +compile_backend: torch-opt + +# Runtime engine +runtime: trtllm + +# Model loading +skip_loading_weights: false + +# Fraction of free memory to use for kv-caches +free_mem_ratio: 0.8 + +# CUDA Graph optimization +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256] + +# Attention backend +attn_backend: flashinfer + +# Sequence configuration +max_batch_size: 256 +``` + +Enable multi-GPU execution by specifying `--tp n`, where `n` is the number of GPUs + +## Configuration Options Reference + +### Core Performance Settings + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `compile_backend` | `torch-compile` | Compilation backend: `torch-simple`, `torch-compile`, `torch-cudagraph`, `torch-opt` | +| `runtime` | `trtllm` | Runtime engine: `trtllm`, `demollm` | +| `free_mem_ratio` | `0.0` | Fraction of available GPU memory for KV cache (0.0-1.0) | +| `skip_loading_weights` | `false` | Skip weight loading for architecture-only benchmarks | + +### CUDA Graph Optimization + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `cuda_graph_batch_sizes` | `null` | List of batch sizes for CUDA graph creation | + +```{tip} +For optimal CUDA graph performance, specify batch sizes that match your expected workload patterns. For example: `[1, 2, 4, 8, 16, 32, 64, 128]` +``` + +## Performance Optimization Tips + +1. **Memory Management**: Set `free_mem_ratio` to 0.8-0.9 for optimal KV cache utilization +1. **Compilation Backend**: Use `torch-opt` for production workloads +1. **Attention Backend**: `flashinfer` generally provides the best performance for most models +1. **CUDA Graphs**: Enable CUDA graphs for batch sizes that match your production traffic patterns. diff --git a/docs/source/torch/auto_deploy/advanced/example_run.md b/docs/source/torch/auto_deploy/advanced/example_run.md new file mode 100644 index 0000000000..7308e7198f --- /dev/null +++ b/docs/source/torch/auto_deploy/advanced/example_run.md @@ -0,0 +1,49 @@ +# Example Run Script + +To build and run AutoDeploy example, use the `examples/auto_deploy/build_and_run_ad.py` script: + +```bash +cd examples/auto_deploy +python build_and_run_ad.py --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +``` + +You can configure your experiment with various options. Use the `-h/--help` flag to see available options: + +```bash +python build_and_run_ad.py --help +``` + +The following is a non-exhaustive list of common configuration options: + +| Configuration Key | Description | +|-------------------|-------------| +| `--model` | The HF model card or path to a HF checkpoint folder | +| `--args.model-factory` | Choose model factory implementation (`"AutoModelForCausalLM"`, ...) | +| `--args.skip-loading-weights` | Only load the architecture, not the weights | +| `--args.model-kwargs` | Extra kwargs that are being passed to the model initializer in the model factory | +| `--args.tokenizer-kwargs` | Extra kwargs that are being passed to the tokenizer initializer in the model factory | +| `--args.world-size` | The number of GPUs used for auto-sharding the model | +| `--args.runtime` | Specifies which type of Engine to use during runtime (`"demollm"` or `"trtllm"`) | +| `--args.compile-backend` | Specifies how to compile the graph at the end | +| `--args.attn-backend` | Specifies kernel implementation for attention | +| `--args.mla-backend` | Specifies implementation for multi-head latent attention | +| `--args.max-seq-len` | Maximum sequence length for inference/cache | +| `--args.max-batch-size` | Maximum dimension for statically allocated KV cache | +| `--args.attn-page-size` | Page size for attention | +| `--prompt.batch-size` | Number of queries to generate | +| `--benchmark.enabled` | Whether to run the built-in benchmark (true/false) | + +For default values and additional configuration options, refer to the `ExperimentConfig` class in `examples/auto_deploy/build_and_run_ad.py` file. + +The following is a more complete example of using the script: + +```bash +cd examples/auto_deploy +python build_and_run_ad.py \ +--model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \ +--args.world-size 2 \ +--args.runtime "demollm" \ +--args.compile-backend "torch-compile" \ +--args.attn-backend "flashinfer" \ +--benchmark.enabled True +``` diff --git a/docs/source/torch/auto_deploy/advanced/expert_configurations.md b/docs/source/torch/auto_deploy/advanced/expert_configurations.md new file mode 100644 index 0000000000..76ba2fe2b4 --- /dev/null +++ b/docs/source/torch/auto_deploy/advanced/expert_configurations.md @@ -0,0 +1,178 @@ +# Expert Configuration of LLM API + +For advanced TensorRT-LLM users, the full set of `tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs` is exposed. Use at your own risk. The argument list may diverge from the standard TRT-LLM argument list. + +- All configuration fields used by the AutoDeploy core pipeline, `InferenceOptimizer`, are exposed exclusively in `AutoDeployConfi`g in `tensorrt_llm._torch.auto_deploy.llm_args`. + Please make sure to refer to those first. +- For advanced users, the full set of `LlmArgs` in `tensorrt_llm._torch.auto_deploy.llm_args` can be used to configure the AutoDeploy `LLM` API, including runtime options. +- Note that some fields in the full `LlmArgs` + object are overlapping, duplicated, and/or _ignored_ in AutoDeploy, particularly arguments + pertaining to configuring the model itself since AutoDeploy's model ingestion+optimize pipeline + significantly differs from the default manual workflow in TensorRT-LLM. +- However, with the proper care the full `LlmArgs` + objects can be used to configure advanced runtime options in TensorRT-LLM. +- Any valid field can be simply provided as keyword argument ("`**kwargs`") to the AutoDeploy `LLM` API. + +# Expert Configuration of `build_and_run_ad.py` + +For advanced users, `build_and_run_ad.py` provides advanced configuration capabilities using a flexible argument parser powered by PyDantic Settings and OmegaConf. You can use dot notation for CLI arguments, provide multiple YAML configuration files, and utilize sophisticated configuration precedence rules to create complex deployment configurations. + +## CLI Arguments with Dot Notation + +The script supports flexible CLI argument parsing using dot notation to modify nested configurations dynamically. You can target any field in both the `ExperimentConfig` in `examples/auto_deploy/build_and_run_ad.py` and nested `AutoDeployConfig` or `LlmArgs` objects in `tensorrt_llm._torch.auto_deploy.llm_args`: + +```bash +# Configure model parameters +# NOTE: config values like num_hidden_layers are automatically resolved into the appropriate nested +# dict value ``{"args": {"model_kwargs": {"num_hidden_layers": 10}}}`` although not explicitly +# specified as CLI arg +python build_and_run_ad.py \ + --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \ + --args.model-kwargs.num-hidden-layers=10 \ + --args.model-kwargs.hidden-size=2048 \ + --args.tokenizer-kwargs.padding-side=left + +# Configure runtime and backend options +python build_and_run_ad.py \ + --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \ + --args.world-size=2 \ + --args.compile-backend=torch-opt \ + --args.attn-backend=flashinfer + +# Configure prompting and benchmarking +python build_and_run_ad.py \ + --model "microsoft/phi-4" \ + --prompt.batch-size=4 \ + --prompt.sp-kwargs.max-tokens=200 \ + --prompt.sp-kwargs.temperature=0.7 \ + --benchmark.enabled=true \ + --benchmark.bs=8 \ + --benchmark.isl=1024 +``` + +## YAML Configuration Files + +Both `ExperimentConfig` and `AutoDeployConfig`/`LlmArgs` inherit from `DynamicYamlMixInForSettings`, which enables you to provide multiple YAML configuration files that are automatically deep-merged at runtime. + +Create a YAML configuration file (e.g., `my_config.yaml`): + +```yaml +# my_config.yaml +args: + model_kwargs: + num_hidden_layers: 12 + hidden_size: 1024 + world_size: 4 + compile_backend: torch-compile + attn_backend: triton + max_seq_len: 2048 + max_batch_size: 16 + transforms: + sharding: + strategy: auto + quantization: + enabled: false + +prompt: + batch_size: 8 + sp_kwargs: + max_tokens: 150 + temperature: 0.8 + top_k: 50 + +benchmark: + enabled: true + num: 20 + bs: 4 + isl: 1024 + osl: 256 +``` + +Create an additional override file (e.g., `production.yaml`): + +```yaml +# production.yaml +args: + world_size: 8 + compile_backend: torch-opt + max_batch_size: 32 + +benchmark: + enabled: false +``` + +Then use these configurations: + +```bash +# Using single YAML config +python build_and_run_ad.py \ + --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \ + --yaml-configs my_config.yaml + +# Using multiple YAML configs (deep merged in order, later files have higher priority) +python build_and_run_ad.py \ + --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \ + --yaml-configs my_config.yaml production.yaml + +# Targeting nested AutoDeployConfig with separate YAML +python build_and_run_ad.py \ + --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \ + --yaml-configs my_config.yaml \ + --args.yaml-configs autodeploy_overrides.yaml +``` + +## Configuration Precedence and Deep Merging + +The configuration system follows a precedence order in which higher priority sources override lower priority ones: + +1. **CLI Arguments** (highest priority) - Direct command line arguments +1. **YAML Configs** - Files specified via `--yaml-configs` and `--args.yaml-configs` +1. **Default Settings** (lowest priority) - Built-in defaults from the config classes + +**Deep Merging**: Unlike simple overwriting, deep merging recursively combines nested dictionaries. For example: + +```yaml +# Base config +args: + model_kwargs: + num_hidden_layers: 10 + hidden_size: 1024 + max_seq_len: 2048 +``` + +```yaml +# Override config +args: + model_kwargs: + hidden_size: 2048 # This will override + # num_hidden_layers: 10 remains unchanged + world_size: 4 # This gets added +``` + +**Nested Config Behavior**: When using nested configurations, outer YAML configuration files become initialization settings for inner objects, giving them higher precedence: + +```bash +# The outer yaml-configs affects the entire ExperimentConfig +# The inner args.yaml-configs affects only the AutoDeployConfig +python build_and_run_ad.py \ + --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \ + --yaml-configs experiment_config.yaml \ + --args.yaml-configs autodeploy_config.yaml \ + --args.world-size=8 # CLI override beats both YAML configs +``` + +## Built-in Default Configuration + +Both `AutoDeployConfig` and `LlmArgs` classes automatically load a built-in `default.yaml` configuration file that provides defaults for the AutoDeploy inference optimizer pipeline. This file is specified in the `_get_config_dict()` function in `tensorrt_llm._torch.auto_deploy.llm_args` and defines default transform configurations for graph optimization stages. + +The built-in defaults are automatically merged with your configurations at the lowest priority level, ensuring that your custom settings always override the defaults. You can inspect the current default configuration to understand the baseline transform pipeline: + +```bash +# View the default configuration +cat tensorrt_llm/_torch/auto_deploy/config/default.yaml + +# Override specific transform settings +python build_and_run_ad.py \ + --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \ + --args.transforms.export-to-gm.strict=true +``` diff --git a/docs/source/torch/auto_deploy/advanced/logging.md b/docs/source/torch/auto_deploy/advanced/logging.md new file mode 100644 index 0000000000..6de45b7567 --- /dev/null +++ b/docs/source/torch/auto_deploy/advanced/logging.md @@ -0,0 +1,14 @@ +# Logging Level + +Use the following env variable to specify the logging level of our built-in logger, ordered by +decreasing verbosity; + +```bash +AUTO_DEPLOY_LOG_LEVEL=DEBUG +AUTO_DEPLOY_LOG_LEVEL=INFO +AUTO_DEPLOY_LOG_LEVEL=WARNING +AUTO_DEPLOY_LOG_LEVEL=ERROR +AUTO_DEPLOY_LOG_LEVEL=INTERNAL_ERROR +``` + +The default log level is `INFO`. diff --git a/docs/source/torch/auto_deploy/advanced/workflow.md b/docs/source/torch/auto_deploy/advanced/workflow.md new file mode 100644 index 0000000000..191fa6f276 --- /dev/null +++ b/docs/source/torch/auto_deploy/advanced/workflow.md @@ -0,0 +1,30 @@ +### Incorporating `auto_deploy` into your own workflow + +AutoDeploy can be seamlessly integrated into existing workflows using TRT-LLM's LLM high-level API. This section provides an example for configuring and invoking AutoDeploy in custom applications. + +The following example demonstrates how to build an LLM object with AutoDeploy integration: + +``` +from tensorrt_llm._torch.auto_deploy import LLM + + +# Construct the LLM high-level interface object with autodeploy as backend +llm = LLM( + model=<HF_MODEL_CARD_OR_DIR>, + world_size=<DESIRED_WORLD_SIZE>, + compile_backend="torch-compile", + model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration + attn_backend="flashinfer", # choose between "triton" and "flashinfer" + attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton) + skip_loading_weights=False, + model_factory="AutoModelForCausalLM", # choose appropriate model factory + mla_backend="MultiHeadLatentAttention", # for models that support MLA + free_mem_ratio=0.8, # fraction of available memory for cache + simple_shard_only=False, # tensor parallelism sharding strategy + max_seq_len=<MAX_SEQ_LEN>, + max_batch_size=<MAX_BATCH_SIZE>, +) + +``` + +For more information about configuring AutoDeploy via the `LLM` API using `**kwargs`, see the AutoDeploy LLM API in `tensorrt_llm._torch.auto_deploy.llm` and the `AutoDeployConfig` class in `tensorrt_llm._torch.auto_deploy.llm_args`. diff --git a/docs/source/torch/auto_deploy/auto-deploy.md b/docs/source/torch/auto_deploy/auto-deploy.md new file mode 100644 index 0000000000..fc00c0ccc3 --- /dev/null +++ b/docs/source/torch/auto_deploy/auto-deploy.md @@ -0,0 +1,80 @@ +# AutoDeploy + +```{note} +This project is under active development and is currently in a prototype stage. The code is experimental, subject to change, and may include backward-incompatible updates. While we strive for correctness, there are no guarantees regarding functionality, stability, or reliability. +``` + +### Seamless Model Deployment from PyTorch to TensorRT-LLM + +AutoDeploy is a prototype designed to simplify and accelerate the deployment of PyTorch models, including off-the-shelf models such as those from the Hugging Face Transformers library, to TensorRT-LLM. + +![AutoDeploy overview](../../media/ad_overview.png) +<sub><em>AutoDeploy overview and relation with TensorRT-LLM's LLM API</em></sub> + +AutoDeploy provides an alternative method for deploying models using the LLM API without requiring code changes to the source model (for example, Hugging Face Transformers models) or manual implementation of inference optimizations, such as KV-caches, multi-GPU parallelism, or quantization. Instead, AutoDeploy extracts a computation graph from the source model and applies inference optimizations through a series of automated graph transformations. AutoDeploy generates an inference-optimized graph that can be directly executed in the TensorRT-LLM PyTorch runtime and leverages various runtime optimizations including in-flight batching, paging, and overlap scheduling. + +### Key Feature: + +- **Seamless Model Translation:** Automatically converts PyTorch/Hugging Face models to TensorRT-LLM without manual rewrites. +- **Unified Model Definition:** Maintain a single source of truth with your original PyTorch/Hugging Face model. +- **Optimized Inference:** Built-in transformations for sharding, quantization, KV-cache integration, MHA fusion, and CudaGraph optimization. +- **Immediate Deployment:** Day-0 support for models with continuous performance enhancements. +- **Quick Setup & Prototyping:** Lightweight pip package for easy installation with a demo environment for fast testing. + +## Get Started + +1. **Install AutoDeploy:** + +AutoDeploy is included with the TRT-LLM installation. + +```bash +sudo apt-get -y install libopenmpi-dev && pip3 install --upgrade pip setuptools && pip3 install tensorrt_llm +``` + +You can refer to [TRT-LLM installation guide](../../installation/linux.md) for more information. + +2. **Run Llama Example:** + +You are now ready to run an in-framework LLama Demo. + +The general entry point for running the AutoDeploy demo is the `build_and_run_ad.py` script, Checkpoints are loaded directly from Huggingface (HF) or a local HF-like directory: + +```bash +cd examples/auto_deploy +python build_and_run_ad.py --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +``` + +## Support Matrix + +AutoDeploy streamlines the model deployment process through an automated workflow designed for efficiency and performance. The workflow begins with a PyTorch model, which is exported using `torch.export` to generate a standard Torch graph. This graph contains core PyTorch ATen operations alongside custom attention operations, determined by the attention backend specified in the configuration. + +The exported graph then undergoes a series of automated transformations, including graph sharding, KV-cache insertion, and GEMM fusion, to optimize model performance. After these transformations, the graph is compiled using one of the supported compile backends (like `torch-opt`), followed by deploying it via the TensorRT-LLM runtime. + +- [Support Matrix](support_matrix.md) + +## Advanced Usage + +- [Example Run Script](./advanced/example_run.md) +- [Logging Level](./advanced/logging.md) +- [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md) +- [Expert Configurations](./advanced/expert_configurations.md) +- [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md) + +## Roadmap + +We are actively expanding AutoDeploy to support a broader range of model architectures and inference features. + +**Upcoming Model Support:** + +- Vision-Language Models (VLMs) + +- Structured State Space Models (SSMs) and Linear Attention architectures + +**Planned Features:** + +- Low-Rank Adaptation (LoRA) + +- Speculative Decoding for accelerated generation + +To track development progress and contribute, visit our [Github Project Board](https://github.com/orgs/NVIDIA/projects/83/views/13). +We welcome community contributions, see `examples/auto_deploy/CONTRIBUTING.md` for guidelines. diff --git a/docs/source/torch/auto_deploy/support_matrix.md b/docs/source/torch/auto_deploy/support_matrix.md new file mode 100644 index 0000000000..c8780cbca1 --- /dev/null +++ b/docs/source/torch/auto_deploy/support_matrix.md @@ -0,0 +1,127 @@ +## Support Matrix + +AutoDeploy streamlines model deployment with an automated workflow designed for efficiency and performance. The workflow begins with a PyTorch model, which is exported using `torch.export` to generate a standard Torch graph. This graph contains core PyTorch ATen operations alongside custom attention operations, determined by the attention backend specified in the configuration. + +The exported graph then undergoes a series of automated transformations, including graph sharding, KV-cache insertion, and GEMM fusion, to optimize model performance. After these transformations, the graph is compiled using one of the supported compile backends (like `torch-opt`), followed by deploying it via the TRT-LLM runtime. + +### Support Models + +**Bring Your Own Model**: AutoDeploy leverages `torch.export` and dynamic graph pattern matching, enabling seamless integration for a wide variety of models without relying on hard-coded architectures. + +AutoDeploy supports Hugging Face models compatible with `AutoModelForCausalLM` and `AutoModelForImageTextToText`. +In addition, the following models have been officially validated using the default configuration: `runtime=trtllm`, `compile_backend=torch-compile`, and `attn_backend=flashinfer` + +<details> +<summary>Click to expand supported models list</summary> + +- Qwen/QwQ-32B +- Qwen/Qwen2.5-0.5B-Instruct +- Qwen/Qwen2.5-1.5B-Instruct +- Qwen/Qwen2.5-3B-Instruct +- Qwen/Qwen2.5-7B-Instruct +- Qwen/Qwen3-0.6B +- Qwen/Qwen3-235B-A22B +- Qwen/Qwen3-30B-A3B +- Qwen/Qwen3-4B +- Qwen/Qwen3-8B +- TinyLlama/TinyLlama-1.1B-Chat-v1.0 +- apple/OpenELM-1_1B-Instruct +- apple/OpenELM-270M-Instruct +- apple/OpenELM-3B-Instruct +- apple/OpenELM-450M-Instruct +- bigcode/starcoder2-15b-instruct-v0.1 +- bigcode/starcoder2-7b +- deepseek-ai/DeepSeek-Prover-V1.5-SFT +- deepseek-ai/DeepSeek-Prover-V2-7B +- deepseek-ai/DeepSeek-R1-Distill-Llama-70B +- deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B +- deepseek-ai/DeepSeek-R1-Distill-Qwen-32B +- google/codegemma-7b-it +- google/gemma-1.1-7b-it +- google/gemma-2-27b-it +- google/gemma-2-2b-it +- google/gemma-2-9b-it +- google/gemma-2b +- google/gemma-3-1b-it +- ibm-granite/granite-3.1-2b-instruct +- ibm-granite/granite-3.1-8b-instruct +- ibm-granite/granite-3.3-2b-instruct +- ibm-granite/granite-3.3-8b-instruct +- ibm-granite/granite-guardian-3.1-2b +- ibm-granite/granite-guardian-3.2-5b +- meta-llama/CodeLlama-34b-Instruct-hf +- meta-llama/CodeLlama-7b-Instruct-hf +- meta-llama/CodeLlama-7b-Python-hf +- meta-llama/Llama-2-13b-chat-hf +- meta-llama/Llama-2-7b-chat-hf +- meta-llama/Llama-3.1-8B-Instruct +- meta-llama/Llama-3.2-1B-Instruct +- meta-llama/Llama-3.2-3B-Instruct +- meta-llama/Llama-3.3-70B-Instruct +- meta-llama/Llama-4-Maverick-17B-128E-Instruct +- meta-llama/Llama-4-Scout-17B-16E-Instruct +- microsoft/Phi-3-medium-128k-instruct +- microsoft/Phi-3-medium-4k-instruct +- microsoft/Phi-4-mini-instruct +- microsoft/Phi-4-mini-reasoning +- microsoft/Phi-4-reasoning +- microsoft/Phi-4-reasoning-plus +- microsoft/phi-4 +- mistralai/Codestral-22B-v0.1 +- mistralai/Mistral-7B-Instruct-v0.2 +- mistralai/Mistral-7B-Instruct-v0.3 +- mistralai/Mixtral-8x22B-Instruct-v0.1 +- nvidia/Llama-3.1-405B-Instruct-FP8 +- nvidia/Llama-3.1-70B-Instruct-FP8 +- nvidia/Llama-3.1-8B-Instruct-FP8 +- nvidia/Llama-3.1-Minitron-4B-Depth-Base +- nvidia/Llama-3.1-Minitron-4B-Width-Base +- nvidia/Llama-3.1-Nemotron-70B-Instruct-HF +- nvidia/Llama-3.1-Nemotron-Nano-8B-v1 +- nvidia/Llama-3_1-Nemotron-51B-Instruct +- nvidia/Llama-3_1-Nemotron-Ultra-253B-v1 +- nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8 +- nvidia/Llama-3_3-Nemotron-Super-49B-v1 +- nvidia/Mistral-NeMo-Minitron-8B-Base +- perplexity-ai/r1-1776-distill-llama-70b + +</details> + +### Runtime Integrations + +AutoDeploy runs natively with the complete `TRT-LLM` stack via the `LLM` API. In addition, we provide a light-weight wrapper of the `LLM` API for onboarding and debugging new models: + +| `"runtime"` | Description | +|-------------|-------------| +| `trtllm` | A robust, production-grade runtime optimized for high-performance inference. | +| `demollm` | A lightweight runtime wrapper designed for development and testing, featuring a naive scheduler and KV-cache manager for simplified debugging and testing. | + +### Compile Backends + +AutoDeploy supports multiple backends for compiling the exported Torch graph: + +| `"compile_backend"` | Description | +|--------------------|-------------| +| `torch-simple` | Exports the graph without additional optimizations. | +| `torch-compile` | Applies `torch.compile` to the graph after all AutoDeploy transformations have been completed. | +| `torch-cudagraph` | Performs CUDA graph capture (without torch.compile). | +| `torch-opt` | Uses `torch.compile` along with CUDA Graph capture to enhance inference performance. | + +### Attention backends + +Optimize attention operations with different attention kernel implementations: + +| `"attn_backend"` | Description | +|----------------------|-------------| +| `triton` | Custom fused multi-head attention (MHA) with KV Cache kernels for efficient attention processing. | +| `flashinfer` | Uses optimized attention kernels with KV Cache from the [`flashinfer`](https://github.com/flashinfer-ai/flashinfer.git) library. | + +### Precision Support + +AutoDeploy supports models with various precision formats, including quantized checkpoints generated by [`TensorRT-Model-Optimizer`](https://github.com/NVIDIA/TensorRT-Model-Optimizer). + +**Supported precision types include:** + +- BF16 / FP16 / FP32 +- FP8 +- [NVFP4](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/) diff --git a/docs/source/torch/features/checkpoint_loading.md b/docs/source/torch/features/checkpoint_loading.md new file mode 100644 index 0000000000..2a54905f3a --- /dev/null +++ b/docs/source/torch/features/checkpoint_loading.md @@ -0,0 +1,332 @@ +# Checkpoint Loading + +The PyTorch backend provides a flexible and extensible infrastructure for loading model checkpoints from different sources and formats, such as HuggingFace (HF) or custom formats, by implementing required components like the checkpoint's weight loader, mapper, and configuration parser. + +## Table of Contents +1. [Overview](#overview) +2. [Core Components](#core-components) +3. [Built-in Checkpoint Formats](#built-in-checkpoint-formats) +4. [Using Checkpoint Loaders](#using-checkpoint-loaders) +5. [Creating Custom Checkpoint Loaders](#creating-custom-checkpoint-loaders) + +## Overview + +The checkpoint loading design is built around a plugin-like architecture that is separated into four distinct components: + +- **Checkpoint Loaders**: Orchestrates the loading process for specific formats. +- **Config Loaders**: Handles model configuration parsing and validation. +- **Weight Loaders**: Manages the actual loading of model weights from storage into memory. +- **Weight Mappers**: Maps and transforms loaded weights to the TRTLLM model's definition. + +This modular design allows for easy extension to support new checkpoint formats while maintaining backward compatibility and performance optimizations. By separating checkpoint loading into four subcomponents, users can leverage existing implementations and introduce custom, checkpoint-specific components. + +To support a new checkpoint format, you must implement all four components. +If the format shares components with an existing framework (such as HF), you only need to implement the components that differ. + +## Core Components + +### BaseCheckpointLoader + +The `BaseCheckpointLoader` is the central interface for all checkpoint loading operations. It provides a unified API regardless of the underlying checkpoint format. This interface is responsible for holding and exposing all objects required for the loading and parsing process. + +**Key Methods:** +- `load_config(checkpoint_dir, **kwargs)`: Loads and returns a `ModelConfig` object +- `load_weights(checkpoint_dir, **kwargs)`: Loads and returns a dictionary of weights +- `get_initialized_weight_mapper(model, config)`: Returns a weight mapper initialized at runtime for the model +- `cleanup()`: Releases resources and cleans up internal state + +### BaseConfigLoader + +Loads model configurations from checkpoint directories and parses them into a TRTLLM `ModelConfig`: + +```python +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader + +class CustomConfigLoader(BaseConfigLoader): + def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: + # Load and parse configuration from your custom format + pretrained_config = self._get_pretrained_config(checkpoint_dir, **kwargs) + + return ModelConfig(pretrained_config=pretrained_config, + ...) + + def _get_pretrained_config(self, checkpoint_dir, **kwargs): + ... + +``` + +### BaseWeightLoader + +Handles the loading of model weights from storage: + +```python +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader + +class CustomWeightLoader(BaseWeightLoader): + def load_weights(self, checkpoint_dir: str) -> dict[str, Any]: + # Load weights from your custom format + # Return a dictionary mapping parameter names to tensors + return weights_dict +``` + +### BaseWeightMapper + +Transforms weights between different naming conventions and applies model-specific transformations to the TRTLLM model object. + +## Built-in Checkpoint Formats + +### HuggingFace Format + +Currently, the HF checkpoint loader is the primary built-in format and supports: + +- **Weights loading** (`.safetensors, .bin, .pth`): Load HF-compatible weights from disk +- **Configuration parser** - Parse configuration information stored by HF into a TRTLLM `ModelConfig` object +- **Weights Mapping** - Convert HF weights into a TRTLLM-compatible representation + +## Using Checkpoint Loaders + +### Basic Usage + +There are two main approaches for using checkpoint loading objects + +The first approach is through the llm-api, as shown in the following example: + +```python +from tensorrt_llm import LLM + +hf_model_dir = "llama-models-v2/llama-v2-13b-hf" + +llm = LLM(model=hf_model_dir) +``` + +In this example, the `HfCheckpointLoader` is selected by default. + +To explicitly set the checkpoint loader, specify the required checkpoint-specific loader: + +```python +from tensorrt_llm import LLM +from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader + +hf_model_dir = "llama-models-v2/llama-v2-13b-hf" + +llm = LLM(model=hf_model_dir, + checkpoint_loader=HfCheckpointLoader()) +``` + +Similarly, to use a basic checkpoint loader with a specific subcomponent, provide the desired subcomponent as needed: + +```python +from tensorrt_llm import LLM +from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader + +hf_model_dir = "llama-models-v2/llama-v2-13b-hf" + +llm = LLM(model=hf_model_dir, + checkpoint_loader=HfCheckpointLoader(weight_loader=MyCustomWeightLoader())) +``` + +In the second approach, you can directly use the individual checkpoint loading components: + +```python +from tensorrt_llm._torch.models.checkpoints.hf.gemma3_weight_mapper import \ + Gemma3HfWeightMapper +from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3ForCausalLM + +gemma3 = Gemma3ForCausalLM(model_config) +weight_mapper = Gemma3HfWeightMapper() +weight_mapper.init_model_and_config(gemma3, model_config) +gemma3.load_weights(hf_gemma3.state_dict(), weight_mapper) +``` +## Creating Custom Checkpoint Loaders + +To support a new checkpoint format, implement all four components. This section provides minimal templates for each. + +### When to Create Custom Components + +- **Complete New Format**: Implement all four components to support a new checkpoint format +- **Custom Weight Storage**: Implement only a custom weight loader if you have a unique weight storage format (such as a custom binary format or database storage) +- **Custom Configuration**: Implement only a custom config loader if your configuration format cannot be parsed by existing loaders +- **Custom Weight Mapping**: Implement only a custom weight mapper if your model has unique weight naming or transformation requirements that are checkpoint-specific + +### Step 1: Create the Checkpoint Loader + +```python +from typing import Optional +from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper +from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader + +@register_checkpoint_loader("CUSTOM_FORMAT") +class CustomCheckpointLoader(BaseCheckpointLoader): + def __init__(self, + *, + weight_loader: Optional[BaseWeightLoader] = None, + weight_mapper: Optional[BaseWeightMapper] = None, + config_loader: Optional[BaseConfigLoader] = None): + self._weight_loader = weight_loader or self.get_default_weight_loader() + self._config_loader = config_loader or self.get_default_config_loader() + self._weight_mapper = weight_mapper + self._checkpoint_format = "CUSTOM_FORMAT" # Set the checkpoint format name + + def get_default_weight_loader(self) -> BaseWeightLoader: + return CustomWeightLoader() + + def get_default_config_loader(self) -> BaseConfigLoader: + return CustomConfigLoader() +``` + +### Step 2: Create the Checkpoint Weight Loader + +```python +from typing import Any +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader +from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_weight_loader + +@register_checkpoint_weight_loader("CUSTOM_FORMAT") +class CustomWeightLoader(BaseWeightLoader): + def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]: + """ + Load weights from your custom format. + + Args: + checkpoint_dir: Directory containing checkpoint files + **kwargs: Additional loading parameters + + Returns: + Dictionary mapping parameter names to tensors + """ + weights = {} # Implement your custom weight loading logic here + + # Examples: + # - Load from custom binary files + # - Load from databases + # - Load from compressed archives + # - Apply custom preprocessing + + return weights +``` + +### Step 3: Create the Checkpoint Config Loader + +```python +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader +from tensorrt_llm._torch.models.modeling_utils import register_config_loader + +@register_config_loader("CUSTOM_FORMAT") +class CustomConfigLoader(BaseConfigLoader): + def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: + """ + Load and parse configuration from your custom format. + + Args: + checkpoint_dir: Directory containing configuration files + **kwargs: Additional loading parameters + + Returns: + ModelConfig object containing parsed configuration + """ + # Load your custom configuration format here + # Examples: + # - Parse YAML/TOML files + # - Convert from proprietary formats + + pretrained_config = self._load_pretrained_config(checkpoint_dir, **kwargs) + + return ModelConfig( + pretrained_config=pretrained_config, + # Add other ModelConfig parameters as needed + ) + + def _load_pretrained_config(self, checkpoint_dir: str, **kwargs): + """Load the raw configuration from your custom format.""" + # Implement as needed + pass +``` + +### Step 4: Create the Checkpoint Weight Mapper + +```python +from torch import nn +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper +from tensorrt_llm._torch.models.modeling_utils import register_mapper + +@register_mapper("CUSTOM_FORMAT") +class CustomWeightMapper(BaseWeightMapper): + def __init__(self): + super().__init__() + # Define any weight transformation callbacks + self._callbacks = [ + # Add your custom weight transformation functions + # self._custom_transform_function, + ] + + def map_weights(self) -> None: + """ + Define mappings between source and target weight names. + """ + self.mapping.update({ + # Map source names to target names + # 'target_module_name': ['source_param1', 'source_param2'], + # For example: 'qkv_proj': ['q_proj', 'k_proj', 'v_proj'] + }) + + def apply_callbacks(self, module: nn.Module, module_name: str, + module_names_breakdown: list[str], + weights: dict) -> list[dict]: + """ + Apply weight transformations for modules that require special handling. + + Args: + module: The target module + module_name: The specific module name being processed + module_names_breakdown: Module path components + weights: Source weights dictionary + + Returns: + List of transformed weight dictionaries + """ + module_weights = [] + + for new_name in self._mapping[module_name]: + # Filter weights for this specific parameter + fw = self.filter_weights( + '.'.join(module_names_breakdown + [new_name]), weights) + + # Apply transformation callbacks + for callback in self._callbacks: + fw = callback(module, new_name, fw) + + module_weights.append(fw) + + return module_weights + + def should_skip_module(self, module_name: str) -> bool: + """ + Define which modules should be skipped during loading. + """ + # Add logic to skip specific modules based on your requirements + # Examples: + # - Skip LoRA-specific modules + # - Skip temporary/auxiliary modules + + return super().should_skip_module(module_name) +``` + +Note: When creating a custom mapper, you can define either a checkpoint-format-specific mapper. For example: + +```python +@register_mapper("CUSTOM_FORMAT") +class CustomWeightMapper(BaseWeightMapper) +``` + +Alternatively, you can define a checkpoint-model-specific mapper. For example: + +```python +@register_mapper("CUSTOM_FORMAT", "Gemma3ForCausalLM") +class CustomWeightMapper(BaseWeightMapper) +``` + +By setting the model name, the registered mapper will be associated with the specific model. diff --git a/docs/source/torch/features/feature_combination_matrix.md b/docs/source/torch/features/feature_combination_matrix.md index 35a10a4959..f39a800fcd 100644 --- a/docs/source/torch/features/feature_combination_matrix.md +++ b/docs/source/torch/features/feature_combination_matrix.md @@ -7,12 +7,12 @@ | Attention Data Parallelism | Yes | Yes | --- | | | | | | | | | | | | | Disaggregated Serving | Yes | Yes | Yes | --- | | | | | | | | | | | | Chunked Prefill | Yes | Yes | Yes | Untested | --- | | | | | | | | | | -| MTP | Yes | Yes | Yes | Yes | Untested | --- | | | | | | | | | -| EAGLE-3(One Model Engine) | Yes | Yes | Yes | No | Yes | No | --- | | | | | | | | -| EAGLE-3(Two Model Engine) | NO | Yes | Yes | No | Yes | No | No | --- | | | | | | | +| MTP | Yes | Yes | Yes | Yes | Yes | --- | | | | | | | | | +| EAGLE-3(One Model Engine) | Yes | Yes | Yes | Yes | Yes | No | --- | | | | | | | | +| EAGLE-3(Two Model Engine) | NO | Yes | Yes | Yes | Yes | No | No | --- | | | | | | | | Torch Sampler | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | --- | | | | | | | TLLM C++ Sampler | Yes | Yes | Yes | Yes | Yes | No | No | No | No | --- | | | | | | KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | | -| Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | | -| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | | -| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- | +| Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | | +| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | | +| Guided Decoding | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- | diff --git a/docs/source/torch/features/lora.md b/docs/source/torch/features/lora.md new file mode 100644 index 0000000000..d00a27d49a --- /dev/null +++ b/docs/source/torch/features/lora.md @@ -0,0 +1,224 @@ +# LoRA (Low-Rank Adaptation) + +LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that enables adapting large language models to specific tasks without modifying the original model weights. Instead of fine-tuning all parameters, LoRA introduces small trainable rank decomposition matrices that are added to existing weights during inference. + +## Table of Contents +1. [Background](#background) +2. [Basic Usage](#basic-usage) + - [Single LoRA Adapter](#single-lora-adapter) + - [Multi-LoRA Support](#multi-lora-support) +3. [Advanced Usage](#advanced-usage) + - [LoRA with Quantization](#lora-with-quantization) + - [NeMo LoRA Format](#nemo-lora-format) + - [Cache Management](#cache-management) +4. [TRTLLM serve with LoRA](#trtllm-serve-with-lora) + - [YAML Configuration](#yaml-configuration) + - [Starting the Server](#starting-the-server) + - [Client Usage](#client-usage) +5. [TRTLLM bench with LORA](#trtllm-bench-with-lora) + - [YAML Configuration](#yaml-configuration) + - [Run trtllm-bench](#run-trtllm-bench) + +## Background + +The PyTorch backend provides LoRA support, allowing you to: +- Load and apply multiple LoRA adapters simultaneously +- Switch between different adapters for different requests +- Use LoRA with quantized models +- Support both HuggingFace and NeMo LoRA formats + +## Basic Usage + +### Single LoRA Adapter + +```python +from tensorrt_llm import LLM +from tensorrt_llm.lora_helper import LoraConfig +from tensorrt_llm.executor.request import LoRARequest +from tensorrt_llm.sampling_params import SamplingParams + +# Configure LoRA +lora_config = LoraConfig( + lora_dir=["/path/to/lora/adapter"], + max_lora_rank=8, + max_loras=1, + max_cpu_loras=1 +) + +# Initialize LLM with LoRA support +llm = LLM( + model="/path/to/base/model", + lora_config=lora_config +) + +# Create LoRA request +lora_request = LoRARequest("my-lora-task", 0, "/path/to/lora/adapter") + +# Generate with LoRA +prompts = ["Hello, how are you?"] +sampling_params = SamplingParams(max_tokens=50) + +outputs = llm.generate( + prompts, + sampling_params, + lora_request=[lora_request] +) +``` + +### Multi-LoRA Support + +```python +# Configure for multiple LoRA adapters +lora_config = LoraConfig( + lora_target_modules=['attn_q', 'attn_k', 'attn_v'], + max_lora_rank=8, + max_loras=4, + max_cpu_loras=8 +) + +llm = LLM(model="/path/to/base/model", lora_config=lora_config) + +# Create multiple LoRA requests +lora_req1 = LoRARequest("task-1", 0, "/path/to/adapter1") +lora_req2 = LoRARequest("task-2", 1, "/path/to/adapter2") + +prompts = [ + "Translate to French: Hello world", + "Summarize: This is a long document..." +] + +# Apply different LoRAs to different prompts +outputs = llm.generate( + prompts, + sampling_params, + lora_request=[lora_req1, lora_req2] +) +``` + +## Advanced Usage + +### LoRA with Quantization + +```python +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization.mode import QuantAlgo + +# Configure quantization +quant_config = QuantConfig( + quant_algo=QuantAlgo.FP8, + kv_cache_quant_algo=QuantAlgo.FP8 +) + +# LoRA works with quantized models +llm = LLM( + model="/path/to/model", + quant_config=quant_config, + lora_config=lora_config +) +``` + +### NeMo LoRA Format + +```python +# For NeMo-format LoRA checkpoints +lora_config = LoraConfig( + lora_dir=["/path/to/nemo/lora"], + lora_ckpt_source="nemo", + max_lora_rank=8 +) + +lora_request = LoRARequest( + "nemo-task", + 0, + "/path/to/nemo/lora", + lora_ckpt_source="nemo" +) +``` + +### Cache Management + +```python +from tensorrt_llm.llmapi.llm_args import PeftCacheConfig + +# Fine-tune cache sizes +peft_cache_config = PeftCacheConfig( + host_cache_size=1024*1024*1024, # 1GB CPU cache + device_cache_percent=0.1 # 10% of free GPU memory +) + +llm = LLM( + model="/path/to/model", + lora_config=lora_config, + peft_cache_config=peft_cache_config +) +``` + +## TRTLLM serve with LoRA + +### YAML Configuration + +Create an `extra_llm_api_options.yaml` file: + +```yaml +lora_config: + lora_target_modules: ['attn_q', 'attn_k', 'attn_v'] + max_lora_rank: 8 +``` + +### Starting the Server + +```bash +python -m tensorrt_llm.commands.serve + /path/to/model \ + --extra_llm_api_options extra_llm_api_options.yaml +``` + +### Client Usage + +```python +import openai + +client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy") + +response = client.completions.create( + model="/path/to/model", + prompt="What is the capital city of France?", + max_tokens=20, + extra_body={ + "lora_request": { + "lora_name": "lora-example-0", + "lora_int_id": 0, + "lora_path": "/path/to/lora_adapter" + } + }, +) +``` + +## TRTLLM bench with LORA + +### YAML Configuration + +Create an `extra_llm_api_options.yaml` file: + +```yaml +lora_config: + lora_dir: + - /workspaces/tensorrt_llm/loras/0 + max_lora_rank: 64 + max_loras: 8 + max_cpu_loras: 8 + lora_target_modules: + - attn_q + - attn_k + - attn_v + trtllm_modules_to_hf_modules: + attn_q: q_proj + attn_k: k_proj + attn_v: v_proj +``` + +### Run trtllm-bench + +```bash +trtllm-bench --model $model_path throughput --dataset $dataset_path --extra_llm_api_options extra-llm-api-options.yaml --num_requests 64 --concurrency 16 +``` diff --git a/docs/source/torch/features/sampling.md b/docs/source/torch/features/sampling.md index 4756903968..ce164cc0ae 100644 --- a/docs/source/torch/features/sampling.md +++ b/docs/source/torch/features/sampling.md @@ -2,12 +2,11 @@ The PyTorch backend supports most of the sampling features that are supported on the C++ backend, such as temperature, top-k and top-p sampling, stop words, bad words, penalty, context and generation logits, and log probs. -In order to use this feature, it is necessary to enable option `enable_trtllm_sampler` in the `LLM` class, and pass a `SamplingParams` object with the desired options as well. The following example prepares two identical prompts which will give different results due to the sampling parameters chosen: +The following example prepares two identical prompts which will give different results due to the sampling parameters chosen: ```python from tensorrt_llm import LLM -llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8', - enable_trtllm_sampler=True) +llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8') sampling_params = SamplingParams( temperature=1.0, top_k=8, @@ -17,4 +16,4 @@ llm.generate(["Hello, my name is", "Hello, my name is"], sampling_params) ``` -When using speculative decoders such as MTP or Eagle-3, the `enable_trtllm_sampler` option is not yet supported and therefore the subset of sampling options available is more restricted. +When using speculative decoders such as MTP or Eagle-3 the subset of sampling options available is more restricted. diff --git a/examples/auto_deploy/README.md b/examples/auto_deploy/README.md index 399d31ce36..cba226e731 100644 --- a/examples/auto_deploy/README.md +++ b/examples/auto_deploy/README.md @@ -6,7 +6,7 @@ <div align="left"> -AutoDeploy is an experimental feature in beta stage designed to simplify and accelerate the deployment of PyTorch models, including off-the-shelf models like those from Hugging Face, to TensorRT-LLM. It automates graph transformations to integrate inference optimizations such as tensor parallelism, KV-caching and quantization. AutoDeploy supports optimized in-framework deployment, minimizing the amount of manual modification needed. +AutoDeploy is a prototype feature in beta stage designed to simplify and accelerate the deployment of PyTorch models, including off-the-shelf models like those from Hugging Face, to TensorRT-LLM. It automates graph transformations to integrate inference optimizations such as tensor parallelism, KV-caching and quantization. AutoDeploy supports optimized in-framework deployment, minimizing the amount of manual modification needed. ______________________________________________________________________ @@ -450,4 +450,4 @@ the current progress in AutoDeploy and where you can help. ## Disclaimer -This project is in active development and is currently in an early (beta) stage. The code is experimental, subject to change, and may include backward-incompatible updates. While we strive for correctness, we provide no guarantees regarding functionality, stability, or reliability. Use at your own risk. +This project is in active development and is currently in an early (beta) stage. The code is in prototype, subject to change, and may include backward-incompatible updates. While we strive for correctness, we provide no guarantees regarding functionality, stability, or reliability. Use at your own risk. diff --git a/examples/auto_deploy/build_and_run_ad.py b/examples/auto_deploy/build_and_run_ad.py index 35879834db..42a2f927dd 100644 --- a/examples/auto_deploy/build_and_run_ad.py +++ b/examples/auto_deploy/build_and_run_ad.py @@ -41,10 +41,6 @@ class PromptConfig(BaseModel): "In simple words and in a single sentence, explain the concept of gravity: ", "How to fix slicing in golf? ", "Where is the capital of Iceland? ", - "How big is the universe? ", - "In simple words and in a single sentence, explain the concept of gravity: ", - "How to fix slicing in golf? ", - "Where is the capital of Iceland? ", ] ) sp_kwargs: Dict[str, Any] = Field( diff --git a/examples/constraints.txt b/examples/constraints.txt index 756d3b8fd2..4ce23b0de7 100644 --- a/examples/constraints.txt +++ b/examples/constraints.txt @@ -1,3 +1,3 @@ -tensorrt_llm==1.0.0rc6 +tensorrt_llm==1.1.0rc1 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/cpp/executor/README.md b/examples/cpp/executor/README.md index fdb5b0d434..4cc9b72ad9 100644 --- a/examples/cpp/executor/README.md +++ b/examples/cpp/executor/README.md @@ -124,10 +124,10 @@ From the `examples/cpp/executor/build` folder, you can also run the `executorExa ``` ./executorExampleDisaggregated -h ``` -Note setting `TRTLLM_USE_MPI_KVCACHE=1` is required to run disaggregated executor. +Note setting `TRTLLM_USE_UCX_KVCACHE=1` is required to run disaggregated executor. For example, you can run : ``` -export TRTLLM_USE_MPI_KVCACHE=1 +export TRTLLM_USE_UCX_KVCACHE=1 mpirun -n <num_ranks> --allow-run-as-root --oversubscribe ./executorExampleDisaggregated --context_engine_dir <path_to_context_engine_dir> --context_rank_size <num_ranks_for_context> --generation_engine_dir <path_to_generation_engine_dir> --generation_rank_size <num_ranks_for_generation> --input_tokens ../inputTokens.csv diff --git a/examples/disaggregated/README.md b/examples/disaggregated/README.md index 99bd3de208..196113d987 100644 --- a/examples/disaggregated/README.md +++ b/examples/disaggregated/README.md @@ -1,38 +1,64 @@ # Disaggregated Serving -To run TensorRT-LLM in disaggregated mode, you must first launch context (prefill) and generation (decode) servers using `trtllm-serve`. +The execution method of disaggregated serving relies on the `trtllm-serve` command. Specifically, compared to the standard usage of `trtllm-serve`, serving requires running this command multiple times to separately start the router and workers (including context and generation) serving components. This document focuses on this approach and provides a detailed guide on how to use it. -## Launching disaggregated servers locally on single node +Please note that disaggregated serving is currently an experimental feature, so the usage described in this document may change in the future. -We use the `cache_transceiver_config` configuration to set up disaggregated serving, which includes the following parameters: +## Startup Procedure + +### Configuration File + +The `trtllm-serve` command supports the `extra-llm-config.yaml` parameter. In the extra LLM configuration file, the `cache_transceiver_config` field is specifically used for disaggregated service. It is mainly used to specify additional parameters required for the KV cache transmission process. ```yaml cache_transceiver_config: + # KV cache transmission backend. Valid options include `DEFAULT` (i.e., UCX), `UCX`, `NIXL`. backend: <str> + # KV cache buffer size. Set it ≥ the maximum ISL (Input Sequence Length) for best performance. max_tokens_in_buffer: <int> ``` -`backend` specifies the communication backend for transferring the kvCache, valid options include `DEFAULT`,`UCX`, `NIXL`, and `MPI`, the default backend is UCX. +The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below. -`max_tokens_in_buffer` defines the buffer size for kvCache transfers, it is recommended to set this value greater than or equal to the maximum ISL (Input Sequence Length) of all requests for optimal performance. +```yaml +# ctx_extra-llm-api-config.yaml -You can use multiple `trtllm-serve` commands to launch the context and generation servers that will be used -for disaggregated serving. For example, you could launch two context servers and one generation servers as follows: +# The overlap scheduler for context servers is currently disabled, as it is +# not yet supported in disaggregated context server architectures. +disable_overlap_scheduler: True +cache_transceiver_config: + backend: UCX + max_tokens_in_buffer: 2048 +``` + +```yaml +# gen_extra-llm-api-config.yaml + +cache_transceiver_config: + backend: UCX + max_tokens_in_buffer: 2048 +``` + +### Basic Usage + +For non-SLURM clusters - particularly in single-node, multi-GPU setups, it is recommended to use standard mode. In such cases, the system does not enforce limits on process creation or termination. + +Suppose we have three CUDA devices on the same machine. The first two devices are used to launch one context model each, and the third device is used to launch one generation model. In this case, the following commands need to be executed. ```bash -# Generate context_extra-llm-api-config.yml -# Overlap scheduler for context servers are disabled because it's not supported for disaggregated context servers yet -echo -e "disable_overlap_scheduler: True\ncache_transceiver_config:\n backend: UCX\n max_tokens_in_buffer: 2048" > context_extra-llm-api-config.yml - # Start context servers -CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8001 --extra_llm_api_options ./context_extra-llm-api-config.yml &> log_ctx_0 & -CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8002 --extra_llm_api_options ./context_extra-llm-api-config.yml &> log_ctx_1 & +CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --host localhost --port 8001 \ + --extra_llm_api_options ./ctx_extra-llm-api-config.yaml &> log_ctx_0 & -# Generate gen_extra-llm-api-config.yml -echo -e "cache_transceiver_config:\n backend: UCX\n max_tokens_in_buffer: 2048" > gen_extra-llm-api-config.yml +CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --host localhost --port 8002 \ + --extra_llm_api_options ./ctx_extra-llm-api-config.yaml &> log_ctx_1 & -# Start generation servers -CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8003 --extra_llm_api_options ./gen_extra-llm-api-config.yml &> log_gen_0 & +# Start generation server +CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --host localhost --port 8003 \ + --extra_llm_api_options ./gen_extra-llm-api-config.yaml &> log_gen_0 & ``` Once the context and generation servers are launched, you can launch the disaggregated @@ -40,11 +66,16 @@ server, which will accept requests from clients and do the orchestration between and generation servers. The disaggregated server can be launched with: ```bash +# Start proxy trtllm-serve disaggregated -c disagg_config.yaml ``` + where `disagg_config.yaml` contains information about the context and generation servers. For the current example, it would look like: + ```yaml +# disagg_config.yaml + hostname: localhost port: 8000 backend: pytorch @@ -61,13 +92,11 @@ generation_servers: Clients can then send requests to the disaggregated server at `localhost:8000`, which is an OpenAI API compatible endpoint. -## Launching disaggregated servers on SLURM clusters -Refer to [Disaggregated Inference Benchmark Scripts](./slurm/). - -## Sending requests to the disaggregated server +#### Sending requests to the disaggregated server Once the context, generation and disaggregated servers are launched, you can send requests to the disaggregated server using curl: + ```bash curl http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ @@ -78,33 +107,124 @@ curl http://localhost:8000/v1/completions \ "temperature": 0 }' -w "\n" ``` + Or using the provided client parsing the prompts from a file and sending request to the disaggregated server specified in the `disagg_config.yaml` file at the `chat` endpoint: + ``` python3 ./clients/disagg_client.py -c disagg_config.yaml -p ./clients/prompts.json -e chat ``` -## Dynamic scaling (Experimental) +### Launching disaggregated servers on SLURM clusters + +To simplify usage, TensorRT-LLM internally relies on MPI spawning processes. However, some clusters do not offer such process flexibility. In these cases, we provide the `trtllm-llmapi-launch` tool to launch all processes at once. Therefore, when using TensorRT-LLM on a Slurm cluster, please refer to the following method. + +#### Single-Node Execution + +After starting the node and entering interactive mode, you can run the following command to prevent process spawning. + +```bash +# Start context servers +CUDA_VISIBLE_DEVICES=0 trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --host localhost --port 8001 \ + --extra_llm_api_options ./ctx_extra-llm-api-config.yaml &> log_ctx_0 & + +CUDA_VISIBLE_DEVICES=1 trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --host localhost --port 8002 \ + --extra_llm_api_options ./ctx_extra-llm-api-config.yaml &> log_ctx_1 & + +# Start generation server +CUDA_VISIBLE_DEVICES=2 trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --host localhost --port 8003 \ + --extra_llm_api_options ./gen_extra-llm-api-config.yaml &> log_gen_0 & + +# Start proxy +trtllm-llmapi-launch trtllm-serve disaggregated -c disagg_config.yaml +``` + +#### Multi-Node Execution + +If the model you are running cannot fit within a single node and requires multiple nodes, +we introduce the startup method using [srun](https://slurm.schedmd.com/srun.html) to run parallel jobs. + +```bash +srun -A <account> -p <partition> -t <time> -N <num_nodes> --ntasks-per-node=<tasks_per_node> \ + --container-image=<container_image> \ + --container-mounts=<mount_paths> \ + --mpi=<mpi_type> \ + bash -c '<your_command>' +``` + +When using `srun`, the `-N` and `--ntasks-per-node` options are two critical parameters that +determine how your job is distributed across the cluster. + +- `-N <num_nodes>`: Specifies how many physical nodes to use. +- `--ntasks-per-node=<num_tasks>`: Specifies how many tasks to run on each node. + +Together, they define the total number of tasks your job will run: + +$$ +\text{Total tasks} = N \times \text{ntasks-per-node} +$$ + +Therefore, the command can be written as follows: + +```bash +# The `container_image` must have the TensorRT-LLM wheel package pre-installed. +# Once the task is successfully launched, an API service will be available externally at http://host_ip:PORT. +# Launch a context with `tp_size=8` using two 4-GPU nodes. +srun -A <account> -p <partition> -t <time> \ + -N 2 --ntasks-per-node=4 \ + --container-image=<container_image> \ + --container-mounts=<mount_paths> \ + --mpi=pmix \ + bash -c "trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tp_size 8 --host 0.0.0.0 --port $PORT --extra_llm_api_options $WORK/ctx_extra-llm-api-config.yaml" + +# Launch a generation with `tp_size=4` using one 4-GPU node. +srun -A <account> -p <partition> -t <time> \ + -N 1 --ntasks-per-node=4 \ + --container-image=<container_image> \ + --container-mounts=<mount_paths> \ + --mpi=pmix \ + bash -c "trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tp_size 4 --host 0.0.0.0 --port $PORT --extra_llm_api_options $WORK/gen_extra-llm-api-config.yaml" + +# Launch a proxy. +# The above-mentioned value needs to be replaced with the IP address of the host machine accessible to external +# clients, and filled in the `disagg_config.yaml` file. +srun -A <account> -p <partition> -t <time> \ + -N 1 --ntasks-per-node=1 \ + --container-image=<container_image> \ + --container-mounts=<mount_paths> \ + --mpi=pmix \ + bash -c "trtllm-llmapi-launch trtllm-serve disaggregated -c $WORK/disagg_config.yaml" +``` + +Additionally, we offer a fully executable script—please refer to [Disaggregated SLURM Scripts](./slurm/simple_example/). + + +## Dynamic scaling (Prototype) Currently, trtllm supports dynamic addition and removal of servers by leveraging ETCD. To enable this feature, you should start the context and generation servers with an additional flag ```--metadata_server_config_file``` and ```--server_role```. Before launching the context and generation servers, you should first start the ETCD server. By default, the ETCD server listens for client requests at ```localhost:2379```. + ```bash etcd ``` -After this, you can enable the dynamic scaling feature for the use case above as follows: -```bash -export TRTLLM_USE_UCX_KVCACHE=1 +After this, you can enable the dynamic scaling feature for the use case above as follows: + +```bash # Context servers -CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8001 --server_role CONTEXT --extra_llm_api_options ./context_extra-llm-api-config.yml --metadata_server_config_file ./metadata_config.yml &> log_ctx_0 & -CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8002 --server_role CONTEXT --extra_llm_api_options ./context_extra-llm-api-config.yml --metadata_server_config_file ./metadata_config.yml &> log_ctx_1 & +CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8001 --server_role CONTEXT --extra_llm_api_options ./ctx_extra-llm-api-config.yaml --metadata_server_config_file ./metadata_config.yaml &> log_ctx_0 & +CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8002 --server_role CONTEXT --extra_llm_api_options ./ctx_extra-llm-api-config.yaml --metadata_server_config_file ./metadata_config.yaml &> log_ctx_1 & # Generation servers -CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8003 --server_role GENERATION --extra_llm_api_options ./gen_extra-llm-api-config.yml --metadata_server_config_file ./metadata_config.yml &> log_gen_0 & +CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8003 --server_role GENERATION --extra_llm_api_options ./gen_extra-llm-api-config.yaml --metadata_server_config_file ./metadata_config.yaml &> log_gen_0 & ``` As for the disaggregated server, you should also specify the --metadata_server_config_file like the following + ```bash -trtllm-serve disaggregated -c disagg_config.yaml -m ./metadata_config.yml +trtllm-serve disaggregated -c disagg_config.yaml -m ./metadata_config.yaml ``` The metadata_config file looks like @@ -120,27 +240,29 @@ The ```hostname``` and ```port``` must match those used when starting the ETCD s ### Dynamically adding servers Users can add servers by directly launching them with trtllm-serve. For example, you can start an additional generation server as follows: + ```bash CUDA_VISIBLE_DEVICES=3 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ --host localhost --port 8004 \ --server_role GENERATION \ - --extra_llm_api_options ./gen_extra-llm-api-config.yml \ - --metadata_server_config_file ./metadata_config.yml &> log_gen_0 & + --extra_llm_api_options ./gen_extra-llm-api-config.yaml \ + --metadata_server_config_file ./metadata_config.yaml &> log_gen_0 & ``` + TensorRT-LLM will automatically register any newly launched server with the ETCD server, allowing the router to send new requests to the added server. ### Dynamically removing servers When removing servers, special attention is required in the current version. You need to first remove the corresponding key from the ETCD server. After you see the log message "Server xxxx is removed," you can then safely shut down the server. This part will be improved soon. -## Launching context and generation servers using MPI (Deprecated) +## Startup Procedure with MPI Worker (Deprecated) + +In the past, we used `disaggregated_mpi_worker` to allow context nodes and generation nodes to operate within the same MPI world. However, this approach conflicts with the dynamic node addition and removal functionality. As a result, disaggregated_mpi_worker has been marked as deprecated, and the corresponding examples will be gradually removed. -One can also launch all context and generation servers using MPI. This can be done by issuing the following command: ```bash -export TRTLLM_USE_MPI_KVCACHE=1 mpirun -n <total_num_ranks> trtllm-serve disaggregated_mpi_worker -c disagg_config.yaml ``` -where `<total_num_ranks>` is the sum of `TP*PP` for all context and generation servers. For the example above, `total_num_ranks` is 3 +where `total_num_ranks` is the sum of `TP*PP` for all context and generation servers. For the example above, `total_num_ranks` is 3 since `TP` and `PP` is 1 for the two context and one generation server. The `disagg_config.yaml` file must now contain the configuration parameters of the context and generation servers. For example, @@ -174,10 +296,9 @@ generation_servers: ``` Once the context and generation servers are launched, you can again launch the disaggregated server with + ```bash trtllm-serve disaggregated -c disagg_config.yaml ``` -## Know Issues - -The MPI communication backend for kvCache transfer has been deprecated and may not be supported in the future. When using the MPI backend, the environment variable `TRTLLM_USE_MPI_KVCACHE=1` should be set to avoid conflicts between mpi4py and kvCache transfer. +The MPI communication backend for KV cache transfer has been deprecated and may not be supported in the future. When using the MPI backend, the environment variable `TRTLLM_USE_MPI_KVCACHE=1` should be set to avoid conflicts between mpi4py and KV cache transfer. diff --git a/examples/disaggregated/disagg_config.yaml b/examples/disaggregated/disagg_config.yaml index ae72c1b074..6b2b4f7123 100644 --- a/examples/disaggregated/disagg_config.yaml +++ b/examples/disaggregated/disagg_config.yaml @@ -11,7 +11,7 @@ context_servers: kv_cache_config: free_gpu_memory_fraction: 0.2 cache_transceiver_config: - backend: "default" + backend: "DEFAULT" urls: - "localhost:8001" generation_servers: @@ -19,6 +19,6 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "default" + backend: "DEFAULT" urls: - "localhost:8002" diff --git a/examples/disaggregated/slurm/README.md b/examples/disaggregated/slurm/benchmark/README.md similarity index 100% rename from examples/disaggregated/slurm/README.md rename to examples/disaggregated/slurm/benchmark/README.md diff --git a/examples/disaggregated/slurm/disaggr_torch.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm similarity index 70% rename from examples/disaggregated/slurm/disaggr_torch.slurm rename to examples/disaggregated/slurm/benchmark/disaggr_torch.slurm index f0ae69b743..9aa4712573 100644 --- a/examples/disaggregated/slurm/disaggr_torch.slurm +++ b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm @@ -38,6 +38,34 @@ container_image=${19} mounts=${20} workdir=${21} model_dir=${22} +trtllm_repo=${23} + +echo "================= parameters =================" +echo "num_ctx_servers: ${num_ctx_servers}" +echo "ctx_tp_size: ${ctx_tp_size}" +echo "ctx_batch_size: ${ctx_batch_size}" +echo "ctx_max_num_tokens: ${ctx_max_num_tokens}" +echo "ctx_enable_attention_dp: ${ctx_enable_attention_dp}" +echo "num_gen_servers: ${num_gen_servers}" +echo "gen_tp_size: ${gen_tp_size}" +echo "gen_batch_size: ${gen_batch_size}" +echo "gen_max_num_tokens: ${gen_max_num_tokens}" +echo "gen_enable_attention_dp: ${gen_enable_attention_dp}" +echo "gen_gpu_memory_fraction: ${gen_gpu_memory_fraction}" +echo "eplb_num_slots: ${eplb_num_slots}" +echo "mtp_size: ${mtp_size}" +echo "concurrency: ${concurrency}" +echo "isl: ${isl}" +echo "osl: ${osl}" +echo "multi_round: ${multi_round}" +echo "streaming: ${streaming}" +echo "container_image: ${container_image}" +echo "mounts: ${mounts}" +echo "workdir: ${workdir}" +echo "model_dir: ${model_dir}" +echo "trtllm_repo: ${trtllm_repo}" +echo "===========================================" + ctx_max_seq_len=$((isl + 1)) gen_max_seq_len=$((isl + osl)) @@ -45,7 +73,7 @@ ctx_gpu_frac=0.75 cache_transceiver_max_num_tokens=8448 container_name=disaggr -logdir=${workdir}/benchmark-${isl}-${osl}/ +logdir=${workdir}/benchmark-${isl}-${osl} mkdir -p ${logdir} full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_dep${gen_tp_size}_batch${gen_batch_size}_eplb${eplb_num_slots}_mtp${mtp_size} @@ -65,9 +93,13 @@ fi mkdir -p ${full_logdir} echo "Log will be saved to: ${full_logdir}" +if [ -z "${TRT_LLM_GIT_COMMIT}" ]; then + export TRT_LLM_GIT_COMMIT=$(git -C ${trtllm_repo} rev-parse --short HEAD 2>/dev/null || echo "unknown") + echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}" +fi + nsys_on="" # nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling - # start the container srun -l --container-image=${container_image} \ --container-name=${container_name} \ @@ -75,6 +107,13 @@ srun -l --container-image=${container_image} \ --mpi=pmix \ echo "Container up." +if [ -n "${trtllm_repo}" ]; then + srun --container-name=${container_name} \ + --container-mounts=${mounts} \ + --mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \ + bash -c "cd ${trtllm_repo} && echo 'Running install operation...' && pip install -e . " 2>&1 | tee ${full_logdir}/install.log +fi + # generate the yaml file srun -l --container-name=${container_name} \ --container-mounts=${mounts} \ @@ -104,11 +143,12 @@ echo "YAML file generated." hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}') echo "server host name: $hostname_value" + # start the workers srun -l --container-name=${container_name} \ --container-mounts=${mounts} \ - --mpi=pmix --overlap \ - bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log & + --mpi=pmix --overlap \ + bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log & # start the server srun -l --container-name=${container_name} \ @@ -121,7 +161,7 @@ srun -l --container-name=${container_name} \ srun -l --container-name=${container_name} \ --container-mounts=${mounts} \ --mpi=pmix --overlap -N 1 -n 1 \ - bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir}/ > ${full_logdir}/benchmark.log 2>&1 + bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir} > ${full_logdir}/benchmark.log 2>&1 # try to kill the server and workers srun -l --container-name=${container_name} \ diff --git a/examples/disaggregated/slurm/gen_yaml.py b/examples/disaggregated/slurm/benchmark/gen_yaml.py similarity index 99% rename from examples/disaggregated/slurm/gen_yaml.py rename to examples/disaggregated/slurm/benchmark/gen_yaml.py index a3f8ad32ac..b3865fd700 100644 --- a/examples/disaggregated/slurm/gen_yaml.py +++ b/examples/disaggregated/slurm/benchmark/gen_yaml.py @@ -197,7 +197,7 @@ def gen_config_file(config_path: str, }, 'cache_transceiver_config': { 'max_tokens_in_buffer': cache_transceiver_max_num_tokens, - 'backend': 'default', + 'backend': 'DEFAULT', }, }, 'generation_servers': { @@ -225,7 +225,7 @@ def gen_config_file(config_path: str, }, 'cache_transceiver_config': { 'max_tokens_in_buffer': cache_transceiver_max_num_tokens, - 'backend': 'default', + 'backend': 'DEFAULT', }, 'stream_interval': 20, } diff --git a/examples/disaggregated/slurm/run_benchmark.sh b/examples/disaggregated/slurm/benchmark/run_benchmark.sh similarity index 74% rename from examples/disaggregated/slurm/run_benchmark.sh rename to examples/disaggregated/slurm/benchmark/run_benchmark.sh index 6cf7d45068..bca7657446 100644 --- a/examples/disaggregated/slurm/run_benchmark.sh +++ b/examples/disaggregated/slurm/benchmark/run_benchmark.sh @@ -16,7 +16,7 @@ isl=$1 osl=$2 multi_round=$3 model_name=$4 -concurrency=$5 +concurrency_list=$5 streaming=$6 log_path=$7 @@ -89,31 +89,31 @@ do_get_logs(){ } # run the loadgen +cp ${log_path}/output_workers.log ${log_path}/workers_start.log +for concurrency in ${concurrency_list}; do + mkdir -p ${log_path}/concurrency_${concurrency} + max_count=$((${concurrency} * ${multi_round})) + echo "Running loadgen with concurrency: ${concurrency}, max_count: ${max_count}" + python -m tensorrt_llm.serve.scripts.benchmark_serving \ + --model ${model_name} \ + --tokenizer ${model_name} \ + --dataset-name random \ + --dataset-path ${shared_gpt_path} \ + --random-input-len ${isl} \ + --random-output-len ${osl} \ + --random-prefix-len 0 \ + --num-prompts ${max_count} \ + --max-concurrency ${concurrency} \ + --host ${hostname} \ + --port ${port} \ + --ignore-eos \ + --no-test-input \ + $(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi) -mkdir -p ${log_path}/concurrency_${concurrency} -cp ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency}/workers_start.log -max_count=$((${concurrency} * ${multi_round})) -echo "Running loadgen with concurrency: ${concurrency}, max_count: ${max_count}" - -python -m tensorrt_llm.serve.scripts.benchmark_serving \ - --model ${model_name} \ - --tokenizer ${model_name} \ - --dataset-name random \ - --dataset-path ${shared_gpt_path} \ - --random-input-len ${isl} \ - --random-output-len ${osl} \ - --random-prefix-len 0 \ - --num-prompts ${max_count} \ - --max-concurrency ${concurrency} \ - --host ${hostname} \ - --port ${port} \ - --ignore-eos \ - --no-test-input \ - $(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi) - -do_get_logs ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency} -# echo "" > ${log_path}/output_workers.log -echo "done for ${concurrency} in folder ${log_path}/concurrency_${concurrency}" + do_get_logs ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency} + echo "" > ${log_path}/output_workers.log + echo "done for ${concurrency} in folder ${log_path}/concurrency_${concurrency}" +done echo "Benchmark done, gracefully shutting down server and workers..." kill -9 $(ps aux | grep '[s]tart_server.sh' | awk '{print $2}') >/dev/null 2>&1 || true diff --git a/examples/disaggregated/slurm/start_server.sh b/examples/disaggregated/slurm/benchmark/start_server.sh similarity index 100% rename from examples/disaggregated/slurm/start_server.sh rename to examples/disaggregated/slurm/benchmark/start_server.sh diff --git a/examples/disaggregated/slurm/start_worker.sh b/examples/disaggregated/slurm/benchmark/start_worker.sh similarity index 100% rename from examples/disaggregated/slurm/start_worker.sh rename to examples/disaggregated/slurm/benchmark/start_worker.sh diff --git a/examples/disaggregated/slurm/submit.sh b/examples/disaggregated/slurm/benchmark/submit.sh similarity index 90% rename from examples/disaggregated/slurm/submit.sh rename to examples/disaggregated/slurm/benchmark/submit.sh index 0498910ce1..95743f7cc1 100644 --- a/examples/disaggregated/slurm/submit.sh +++ b/examples/disaggregated/slurm/benchmark/submit.sh @@ -7,6 +7,7 @@ container_image=<container_image> mounts=<mounts> # e.g. /mnt/data:/mnt/data workdir=<workdir> # Path to disaggr_torch.slurm model_dir=<model_dir> # Path to the model checkpoint +repo_dir=<repo_dir> # Path to the repo to install TensorRT-LLM, if this is empty, the pre-installed version will be used ntasks_per_node=4 # 4 GPUs per GB200 node total_node_num=8 @@ -31,6 +32,7 @@ args=( $mounts $workdir $model_dir + $repo_dir ) # This command starts a job with 8 nodes, 32 GPUs in total. diff --git a/examples/disaggregated/slurm/simple_example/ctx_extra-llm-api-config.yaml b/examples/disaggregated/slurm/simple_example/ctx_extra-llm-api-config.yaml new file mode 100644 index 0000000000..ca3cf4cb2a --- /dev/null +++ b/examples/disaggregated/slurm/simple_example/ctx_extra-llm-api-config.yaml @@ -0,0 +1,6 @@ +# The overlap scheduler for context servers is currently disabled, as it is +# not yet supported in disaggregated context server architectures. +disable_overlap_scheduler: True +cache_transceiver_config: + backend: UCX + max_tokens_in_buffer: 2048 diff --git a/examples/disaggregated/slurm/simple_example/disagg_config.yaml b/examples/disaggregated/slurm/simple_example/disagg_config.yaml new file mode 100644 index 0000000000..8e4eb39240 --- /dev/null +++ b/examples/disaggregated/slurm/simple_example/disagg_config.yaml @@ -0,0 +1,12 @@ +# Please replace `ctx_hostname` and `gen_hostname` with the actual addresses. +hostname: localhost +port: 8000 +backend: pytorch +context_servers: + num_instances: 1 + urls: + - "ctx_hostname:8001" +generation_servers: + num_instances: 1 + urls: + - "gen_hostname:8002" diff --git a/examples/disaggregated/slurm/simple_example/gen_extra-llm-api-config.yaml b/examples/disaggregated/slurm/simple_example/gen_extra-llm-api-config.yaml new file mode 100644 index 0000000000..9daa30d6bc --- /dev/null +++ b/examples/disaggregated/slurm/simple_example/gen_extra-llm-api-config.yaml @@ -0,0 +1,3 @@ +cache_transceiver_config: + backend: UCX + max_tokens_in_buffer: 2048 diff --git a/examples/disaggregated/slurm/simple_example/launch.slurm b/examples/disaggregated/slurm/simple_example/launch.slurm new file mode 100644 index 0000000000..6afd40d7fd --- /dev/null +++ b/examples/disaggregated/slurm/simple_example/launch.slurm @@ -0,0 +1,36 @@ +#!/bin/bash +#SBATCH --partition=${partition} +#SBATCH --account=${account} +#SBATCH --job-name=${job_name} +#SBATCH --time=02:00:00 + +container_image="" +mount_paths="" +work_path="" +ctx_port=8001 +gen_port=8002 + +# The `container_image` must have the TensorRT-LLM wheel package pre-installed. +# Once the task is successfully launched, an API service will be available externally at http://host_ip:PORT. +# Launch a context with `tp_size=8` using two 4-GPU nodes. +srun --container-image=${container_image} \ + --container-mounts=${mount_paths} \ + -N 2 --ntasks-per-node=4 \ + --mpi=pmix \ + bash -c "trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tp_size 8 --host 0.0.0.0 --port ${ctx_port} --extra_llm_api_options ${work_path}/ctx_extra-llm-api-config.yaml" & + +# Launch a generation with `tp_size=4` using one 4-GPU node. +srun --container-image=${container_image} \ + --container-mounts=${mount_paths} \ + -N 1 --ntasks-per-node=4 \ + --mpi=pmix \ + bash -c "trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tp_size 8 --host 0.0.0.0 --port ${gen_port} --extra_llm_api_options ${work_path}/gen_extra-llm-api-config.yaml" & + +# Launch a proxy. +# The above-mentioned value needs to be replaced with the IP address of the host machine accessible to external +# clients, and filled in the `disagg_config.yaml` file. +srun --container-image=${container_image} \ + --container-mounts=${mount_paths} \ + -N 1 --ntasks-per-node=1 \ + --mpi=pmix \ + bash -c "trtllm-llmapi-launch trtllm-serve disaggregated -c ${work_path}/disagg_config.yaml" diff --git a/examples/eagle/README.md b/examples/eagle/README.md index 637223afb9..0b103ca40e 100644 --- a/examples/eagle/README.md +++ b/examples/eagle/README.md @@ -98,7 +98,6 @@ To run non-greedy sampling and use typical acceptance, set `--eagle_posterior_th `--temperature` can be specified as well. When no `--eagle_posterior_threshold` is specified or `--temperature=0.0` is set, greedy sampling is used. #### Run EAGLE-2 -**EAGLE-2 is still under the experimental stage.** EAGLE-2 can be enabled with 2 runtime flags (`--eagle_use_dynamic_tree` and `--eagle_dynamic_tree_max_top_k=N`). The same engine can be used for EAGLE-1 and EAGLE-2. Eagle choices must not be set in case of EAGLE-2. EAGLE-2 will generate the tree corresponding to choices dynamically in the runtime. For more details, please refer to [EAGLE-2 paper](https://arxiv.org/pdf/2406.16858). diff --git a/examples/llm-api/llm_multilora.py b/examples/llm-api/llm_multilora.py index 60795b6c60..0c6fa4f541 100644 --- a/examples/llm-api/llm_multilora.py +++ b/examples/llm-api/llm_multilora.py @@ -5,7 +5,7 @@ from huggingface_hub import snapshot_download from tensorrt_llm import LLM from tensorrt_llm.executor import LoRARequest -from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.lora_helper import LoraConfig def main(): diff --git a/examples/llm-api/out_of_tree_example/readme.md b/examples/llm-api/out_of_tree_example/readme.md new file mode 100644 index 0000000000..1b26ea3cd6 --- /dev/null +++ b/examples/llm-api/out_of_tree_example/readme.md @@ -0,0 +1,52 @@ +# Out-of-tree Model Development +The file `modeling_opt.py` shows an example of how a custom model can be defined using TRT-LLM APIs without modifying the source code of TRT-LLM. + +The file `main.py` shows how to run inference for such custom models using the LLM API. + + +## Out-of-tree Multimodal Models + +For multimodal models, TRT-LLM provides `quickstart_multimodal.py` to quickly run a multimodal model that is defined within TRT-LLM. `trtllm-bench` can be used for benchmarking such models. +However, the following sections describe how to use those tools for out-of-tree models. + +### Pre-requisite +To use an out-of-tree model with the quickstart example and trtllm-bench, you need to prepare the model definition files similar to a python module. +Consider the following file structure as an example: +``` +modeling_custom_phi +|-- __init__.py +|-- configuration.py +|-- modeling_custom_phi.py +|-- encoder + |-- __init__.py + |-- configuration.py + |-- modeling_encoder.py +```` +The files `__init__.py` should be populated with the right imports for the custom model. For example, the `modeling_custom_phi/__init__.py` can contain something like: +``` +from .modeling_custom_phi import MyVLMForConditionalGeneration +from . import encoder +``` + +### Quickstart Example + +Once the model definition files are prepared as a python module (as described above), you can use the `--custom_module_dirs` flag in `quickstart_multimodal.py` to load your model and run inference. + +``` +python3 quickstart_multimodal.py --model_dir ./model_ckpt --modality image --max_tokens 10 --prompt "Describe the image." --media ./demo_lower.png --image_format pil --custom_module_dirs ../modeling_custom_phi +``` + +### Benchmarking + +Similar to the quickstart example, you can use the same CLI argument with `trtllm-bench` to benchmark a custom model. + +Prepare the dataset: +``` +python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl +``` + + +Run the benchmark: +``` +trtllm-bench --model ./model_ckpt --model_path ./model_ckpt throughput --dataset mm_data.jsonl --backend pytorch --num_requests 100 --max_batch_size 4 --modality image --streaming --custom_module_dirs ../modeling_custom_phi +``` diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 13740f3d3c..61240b496d 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -53,7 +53,7 @@ def add_llm_args(parser): default='CUTLASS', choices=[ 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', - 'DEEPGEMM', 'CUTEDSL' + 'DEEPGEMM', 'CUTEDSL', 'TRITON' ]) parser.add_argument('--enable_attention_dp', default=False, @@ -65,9 +65,9 @@ def add_llm_args(parser): parser.add_argument('--attention_dp_batching_wait_iters', type=int, default=0) - parser.add_argument('--enable_trtllm_sampler', - default=False, - action='store_true') + parser.add_argument('--sampler_type', + default="auto", + choices=["auto", "TorchSampler", "TRTLLMSampler"]) parser.add_argument('--tp_size', type=int, default=1) parser.add_argument('--pp_size', type=int, default=1) parser.add_argument('--moe_ep_size', type=int, default=-1) @@ -108,6 +108,9 @@ def add_llm_args(parser): default=False, action='store_true', help='Use piecewise CUDA graph to optimize the model') + parser.add_argument('--apply_chat_template', + default=False, + action='store_true') # Sampling parser.add_argument("--max_tokens", type=int, default=64) @@ -227,7 +230,7 @@ def setup_llm(args, **kwargs): args.use_piecewise_cuda_graph) if args.use_torch_compile else None, moe_config=MoeConfig(backend=args.moe_backend), - enable_trtllm_sampler=args.enable_trtllm_sampler, + sampler_type=args.sampler_type, max_seq_len=args.max_seq_len, max_batch_size=args.max_batch_size, max_num_tokens=args.max_num_tokens, @@ -243,8 +246,7 @@ def setup_llm(args, **kwargs): trust_remote_code=args.trust_remote_code, gather_generation_logits=args.return_generation_logits, max_beam_width=args.max_beam_width, - **kwargs, - ) + **kwargs) use_beam_search = args.max_beam_width > 1 best_of = args.best_of or args.n @@ -274,6 +276,15 @@ def main(): prompts = args.prompt if args.prompt else example_prompts llm, sampling_params = setup_llm(args) + new_prompts = [] + if args.apply_chat_template: + for prompt in prompts: + messages = [{"role": "user", "content": f"{prompt}"}] + new_prompts.append( + llm.tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True)) + prompts = new_prompts outputs = llm.generate(prompts, sampling_params) for i, output in enumerate(outputs): diff --git a/examples/llm-api/quickstart_multimodal.py b/examples/llm-api/quickstart_multimodal.py index fc18671ee2..25401e1c95 100644 --- a/examples/llm-api/quickstart_multimodal.py +++ b/examples/llm-api/quickstart_multimodal.py @@ -4,8 +4,9 @@ import os from quickstart_advanced import add_llm_args, setup_llm -from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS, - default_multimodal_input_loader) +from tensorrt_llm.inputs import default_multimodal_input_loader +from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY +from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir example_medias_and_prompts = { "image": { @@ -79,10 +80,11 @@ example_medias_and_prompts = { def add_multimodal_args(parser): - parser.add_argument("--model_type", - type=str, - choices=ALL_SUPPORTED_MULTIMODAL_MODELS, - help="Model type.") + parser.add_argument( + "--model_type", + type=str, + choices=MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(), + help="Model type as specified in the HuggingFace model config.") parser.add_argument("--modality", type=str, choices=[ @@ -90,7 +92,7 @@ def add_multimodal_args(parser): "multiple_image", "mixture_text_image" ], default="image", - help="Media type.") + help="Media type being used for inference.") parser.add_argument("--media", type=str, nargs="+", @@ -108,6 +110,18 @@ def add_multimodal_args(parser): type=str, default="cpu", help="The device to have the input on.") + parser.add_argument( + "--custom_module_dirs", + type=str, + nargs="+", + default=None, + help= + ("Paths to an out-of-tree model directory which should be imported." + " This is useful to load a custom model. The directory should have a structure like:" + " <model_name>" + " ├── __init__.py" + " ├── <model_name>.py" + " └── <sub_dirs>")) return parser @@ -140,6 +154,15 @@ def parse_arguments(): def main(): args = parse_arguments() + if args.custom_module_dirs is not None: + for custom_module_dir in args.custom_module_dirs: + try: + import_custom_module_from_dir(custom_module_dir) + except Exception as e: + print( + f"Failed to import custom module from {custom_module_dir}: {e}" + ) + raise e lora_config = None if args.load_lora: @@ -159,8 +182,11 @@ def main(): model_type = args.model_type else: model_type = json.load( - open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type'] - assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}" + open(os.path.join(str(llm._hf_model_dir), + 'config.json')))['model_type'] + assert model_type in MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(), \ + f"Unsupported model_type: {model_type} found!\n" \ + f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}" # set prompts and media to example prompts and images if they are not provided if args.prompt is None: @@ -168,7 +194,7 @@ def main(): if args.media is None: args.media = example_medias_and_prompts[args.modality]["media"] inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer, - model_dir=llm._hf_model_dir, + model_dir=str(llm._hf_model_dir), model_type=model_type, modality=args.modality, prompts=args.prompt, diff --git a/examples/llm-api/star_attention.py b/examples/llm-api/star_attention.py index e6071054fe..367f7cc843 100644 --- a/examples/llm-api/star_attention.py +++ b/examples/llm-api/star_attention.py @@ -7,6 +7,7 @@ from difflib import SequenceMatcher import torch from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.mapping import CpType from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig @@ -59,7 +60,7 @@ def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False): kv_cache_quant_algo=QuantAlgo.FP8 if fp8_kv_cache else None) if fp8 else QuantConfig() cp_config = { - "cp_type": "star_attention", + "cp_type": CpType.STAR, "cp_anchor_size": args.sa_anchor_size, "block_size": args.sa_block_size } diff --git a/examples/models/core/deepseek_v3/README.md b/examples/models/core/deepseek_v3/README.md index 3f05358805..b15d078519 100644 --- a/examples/models/core/deepseek_v3/README.md +++ b/examples/models/core/deepseek_v3/README.md @@ -30,7 +30,7 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/ - [trtllm-serve](#trtllm-serve) - [Disaggregated Serving](#disaggregated-serving) - [Dynamo](#dynamo) - - [tensorrtllm\_backend for triton inference server (Experimental)](#tensorrtllm_backend-for-triton-inference-server-experimental) + - [tensorrtllm\_backend for triton inference server (Prototype)](#tensorrtllm_backend-for-triton-inference-server-prototype) - [Advanced Usages](#advanced-usages) - [Multi-node](#multi-node) - [mpirun](#mpirun) @@ -392,8 +392,8 @@ settings for your specific use case. NVIDIA Dynamo is a high-throughput low-latency inference framework designed for serving generative AI and reasoning models in multi-node distributed environments. Dynamo supports TensorRT-LLM as one of its inference engine. For details on how to use TensorRT-LLM with Dynamo please refer to [LLM Deployment Examples using TensorRT-LLM](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/README.md) -### tensorrtllm_backend for triton inference server (Experimental) -To serve the model using [tensorrtllm_backend](https://github.com/triton-inference-server/tensorrtllm_backend.git), make sure the version is v0.19+ in which the pytorch path is added as an experimental feature. +### tensorrtllm_backend for triton inference server (Prototype) +To serve the model using [tensorrtllm_backend](https://github.com/triton-inference-server/tensorrtllm_backend.git), make sure the version is v0.19+ in which the pytorch path is added as a prototype feature. The model configuration file is located at https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/all_models/llmapi/tensorrt_llm/1/model.yaml @@ -786,7 +786,7 @@ The converted checkpoint could be used as `<YOUR_MODEL_DIR>` and consumed by oth KV cache reuse is supported for MLA on SM90 and SM100. It is enabled by default. Due to extra operations like memcpy and GEMMs, GPU memory consumption may be higher and the E2E performance may have regression in some cases. Users could pass `KvCacheConfig(enable_block_reuse=False)` to LLM API to disable it. ### Chunked Prefill -Chunked Prefill is supported for MLA only on SM100 currently. You should add `--enable_chunked_prefill` to enable it. The GPU memory consumption is highly correlated with `max_num_tokens` and `max_batch_size`. If encountering out-of-memory errors, you may make these values smaller. (`max_num_tokens` must be divisible by kv cache's `tokens_per_block`) +Chunked Prefill is supported for MLA only on SM90 and SM100 currently. You should add `--enable_chunked_prefill` to enable it. The GPU memory consumption is highly correlated with `max_num_tokens` and `max_batch_size`. If encountering out-of-memory errors, you may make these values smaller. (`max_num_tokens` must be divisible by kv cache's `tokens_per_block`) More specifically, we can imitate what we did in the [Quick Start](#quick-start): diff --git a/examples/models/core/enc_dec/convert_checkpoint.py b/examples/models/core/enc_dec/convert_checkpoint.py index 577e89e941..9c1951975b 100755 --- a/examples/models/core/enc_dec/convert_checkpoint.py +++ b/examples/models/core/enc_dec/convert_checkpoint.py @@ -14,10 +14,11 @@ import safetensors from helper import (convert_weight_to_dtype, fairseq_sin_pos_embedding, fuse_qkv_one_layer, reshape, split) from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration, - MBartForConditionalGeneration, + MBartForConditionalGeneration, NougatProcessor, Pix2StructForConditionalGeneration, T5ForConditionalGeneration, VisionEncoderDecoderModel) +from tensorrt_llm._utils import pad_vocab_size from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType, MLPType) from tensorrt_llm.layers import LanguageAdapterConfig @@ -30,6 +31,9 @@ layernorm_type_map = {i.name: i.value for i in LayerNormType} layernorm_position_map = {i.name: i.value for i in LayerNormPositionType} mlp_type_map = {i.name: i.value for i in MLPType} +# Constants for specific model configurations +ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000 + def copy_args_to_component_config(component_config, args): for arg in vars(args): @@ -619,14 +623,19 @@ def parse_bart_config(args, hf_model): config = configparser.ConfigParser() config['decoder'] = dict() - for key, val in hf_model.model.decoder.config.to_dict().items(): - config["decoder"][key] = f"{val}" + if args.eclair_radio: + for key, val in hf_model.config.to_dict().items(): + config["decoder"][key] = f"{val}" + else: + for key, val in hf_model.model.decoder.config.to_dict().items(): + config["decoder"][key] = f"{val}" config["decoder"]["q_scaling"] = '1' config["decoder"]["rescale_before_lm_head"] = str(False) config['decoder']['has_model_final_layernorm'] = str( - args.nougat or isinstance(hf_model, MBartForConditionalGeneration)) + args.nougat or args.eclair_radio + or isinstance(hf_model, MBartForConditionalGeneration)) - if args.nougat: + if args.nougat or args.eclair_radio: # These flags are true for mbart decoders, but missing in HF config config['decoder']['normalize_before'] = str(True) config['decoder']['normalize_embeddings'] = str(True) @@ -763,10 +772,14 @@ def parse_bart_config(args, hf_model): return component_config encoder_config = None - if not args.nougat: + if not (args.nougat or args.eclair_radio): encoder_config = parse_bart_config_by_component(config, "encoder", args) decoder_config = parse_bart_config_by_component(config, "decoder", args) + # Override n_positions for eclair_radio model + if args.eclair_radio: + decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS + return encoder_config, decoder_config @@ -952,11 +965,22 @@ def convert_bart_weights_to_tllm_safetensors(config, component, params): (hidden_size * 3 // mapping.tp_size))) if component == 'decoder': + import torch + lm_head_weights = params['lm_head.weight'].clone().detach() + vocab_size = config.vocab_size + if params['lm_head.weight'].shape[0] % mapping.tp_size != 0: + vocab_size_padded = pad_vocab_size(config.vocab_size, + mapping.tp_size) + pad_width = vocab_size_padded - config.vocab_size + + lm_head_weights = torch.nn.functional.pad(lm_head_weights, + (0, 0, 0, pad_width), + 'constant', + value=0) + vocab_size = vocab_size_padded weights['lm_head.weight'] = reshape( - split(params['lm_head.weight'], - mapping.tp_size, - mapping.tp_rank, - dim=0), (config.vocab_size // mapping.tp_size, hidden_size)) + split(lm_head_weights, mapping.tp_size, mapping.tp_rank, dim=0), + (vocab_size // mapping.tp_size, hidden_size)) if config.has_model_final_layernorm: weights['transformer.ln_f.weight'] = params[ @@ -1479,6 +1503,113 @@ def get_model(args): if args.nougat: model = VisionEncoderDecoderModel.from_pretrained(args.model_dir) model = model.get_decoder() + elif args.eclair_radio: + import torch + + class RadioWithNeck(torch.nn.Module): + + def __init__(self): + super().__init__() + + self.model_encoder = torch.hub.load("NVlabs/RADIO", + "radio_model", + version="radio_v2.5-h") + self.model_encoder.summary_idxs = torch.tensor(4) + + self.conv1 = torch.nn.Conv1d(1280, 1024, 1) + self.layer_norm1 = torch.nn.LayerNorm( + 1024, eps=1e-6, elementwise_affine=True) + self.conv2 = torch.nn.Conv2d(1024, + 1024, + kernel_size=(1, 4), + stride=(1, 4), + padding=0, + bias=False) + self.layer_norm2 = torch.nn.LayerNorm( + 1024, eps=1e-6, elementwise_affine=True) + + def forward(self, pixel_values): + _, feature = self.model_encoder(pixel_values) + output = self.conv1(feature.permute(0, 2, + 1)).permute(0, 2, 1) + output = self.layer_norm1(output).permute(0, 2, 1) + + b, d, _ = output.shape + h = pixel_values.shape[-2] // 16 + w = pixel_values.shape[-1] // 16 + output = self.conv2(output.reshape(b, d, h, w)) + output = output.flatten(-2, -1).permute(0, 2, 1) + output = self.layer_norm2(output) + return output + + def get_processor(): + processor = NougatProcessor.from_pretrained( + "facebook/nougat-base") + + special_tokens = { + "output_plain_index": "<output_plain>", + "output_markdown_index": "<output_markdown>", + "output_no_text_index": "<output_no_text>", + "output_ocr_index": "<output_ocr>", + "predict_bbox_index": "<predict_bbox>", + "no_bbox_index": "<no_bbox>", + "bbox_start_index": "<bbox>", # not used but can keep + # "bbox_end_index": "</bbox>", # not used but can keep + "no_class_index": "<no_classes>", + "predict_classes_index": "<predict_classes>", + } + for key, special_t in special_tokens.items(): + processor.tokenizer.add_special_tokens( + {"additional_special_tokens": [special_t]}) + setattr(processor.tokenizer, key, + processor.tokenizer.encode(special_t)[1]) + + # Add regular tokens for boxes + processor.tokenizer.add_tokens( + [f"<x_{x_i}>" for x_i in range(1024)]) + processor.tokenizer.add_tokens( + [f"<y_{y_i}>" for y_i in range(1280)]) + # Add regular tokens for classes + #"<class_{class_i}>" + possible_classes = [ + "Text", "Title", "Section-header", "List-item", "TOC", + "Bibliography", "Footnote", "Page-header", "Page-footer", + "Picture", "Formula", "Page-number", "Table", "Caption" + ] + processor.tokenizer.add_tokens( + [f"<class_{cls}>" for cls in possible_classes]) + return processor + + processor = get_processor() + model = VisionEncoderDecoderModel.from_pretrained( + "facebook/nougat-base") + model.encoder = RadioWithNeck() + model.decoder.resize_token_embeddings(len(processor.tokenizer), + pad_to_multiple_of=64) + model.config.decoder_start_token_id = processor.tokenizer.eos_token_id # 2 + model.config.pad_token_id = processor.tokenizer.pad_token_id # 1 + from transformers.models.mbart.modeling_mbart import \ + MBartLearnedPositionalEmbedding + _, d_model = model.device, model.config.decoder.d_model + + with torch.inference_mode(): + # Inspect checkpoint shapes + safetensors.torch.load_model(model, + os.path.join( + args.model_dir, + "model.safetensors"), + strict=False) + model.decoder.model.decoder.embed_positions = MBartLearnedPositionalEmbedding( + ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS, d_model) + model.decoder.model.decoder.embed_positions.weight.data.zero_() + model.decoder.model.decoder.embed_positions.weight.requires_grad_( + True) + model.decoder.lm_head.weight = model.decoder.get_input_embeddings( + ).weight + + model.eval() + model = model.get_decoder() + else: model = AutoModelForSeq2SeqLM.from_pretrained(args.model_dir) elif args.model_type == "pix2struct": @@ -1522,14 +1653,23 @@ def convert_checkpoint(args): quant_algo = None model_type = args.model_type if args.model_type != "blip2" else "t5" - encoder_config, decoder_config = globals()[f'parse_{model_type}_config']( - args, model) + parse_config_mapper = { + 't5': parse_t5_config, + 'pix2struct': parse_pix2struct_config, + 'blip2': parse_t5_config, # blip2 uses t5 config parser + 'language_adapter': parse_language_adapter_config, + 'nmt': parse_nmt_config, + 'bart': parse_bart_config, + } + encoder_config, decoder_config = parse_config_mapper[model_type](args, + model) additional_settings = ["gated_act"] if model_type == 'language_adapter': additional_settings += ["residual_scaling", "language_adapter_config"] - if not args.nougat and args.model_type != "pix2struct": + if not (args.nougat + or args.eclair_radio) and args.model_type != "pix2struct": tllm_encoder_config = { 'architecture': "EncoderModel", 'dtype': args.dtype, @@ -1664,7 +1804,8 @@ def convert_checkpoint(args): decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding if args.workers == 1: - if not args.nougat and args.model_type != "pix2struct": + if not (args.nougat + or args.eclair_radio) and args.model_type != "pix2struct": convert(0, world_size, args, tllm_encoder_config, encoder_convert_args, encoder_saved_dir) convert(0, world_size, args, tllm_decoder_config, decoder_convert_args, @@ -1674,7 +1815,8 @@ def convert_checkpoint(args): args.workers = world_size LOGGER.info(f'Convert checkpoint using {args.workers} workers.') import torch.multiprocessing as mp - if not args.nougat and args.model_type != "pix2struct": + if not (args.nougat + or args.eclair_radio) and args.model_type != "pix2struct": mp.spawn(convert, nprocs=args.workers, args=(world_size, args, tllm_encoder_config, @@ -1736,6 +1878,9 @@ if __name__ == "__main__": parser.add_argument("--nougat", action="store_true", help="Model which uses vision encoder + mbart decoder") + parser.add_argument("--eclair_radio", + action="store_true", + help="Model which uses vision encoder + mbart decoder") parser.add_argument("--verbose", action="store_true", help="Provide verbose messages") diff --git a/examples/models/core/gpt_oss/README.md b/examples/models/core/gpt_oss/README.md new file mode 100644 index 0000000000..cda0086efe --- /dev/null +++ b/examples/models/core/gpt_oss/README.md @@ -0,0 +1,149 @@ +# GPT-OSS + +## Overview + +GPT-OSS is a reasoning model with MoE weights quantized with mxfp4. All the other weights are in bf16. + +## MoE Support Matrix + +In MoE, the weights are pre-quantized to mxfp4. The activation can be in either bf16 (Hopper) or mxfp8 (Blackwell), with similar accuracy. FP8 activation with per-tensor scaling factor has limited support. Note that the per-tensor scaling factor needs to be calculated dynamically during inference with the official mxfp4 checkpoints, which may negatively impact perf. The configs in **bold** are the recommended configs for the official checkpoints. + +| device | Activation | Weight | Supported moe_backend | MMA| +|----------|----------|----------|----------|----------| +| Hopper | **bf16** | mxfp4 | **TRITON**, CUTLASS | simulated mxfp4, HGMMA | +| Hopper | fp8 | mxfp4 | CUTLASS (not enabled) | simulated mxfp4, QGMMA | +| Blackwell | **mxfp8** | mxfp4 | **CUTLASS, TRTLLM** | UTCQMMA | +| Blackwell | fp8 | mxfp4 | CUTLASS, TRTLLM | UTCQMMA | +| Blackwell | fp8 | mxfp4 | TRITON (experimental) | NA | +| Blackwell | bf16 | mxfp4 | TRTLLM | simulated mxfp4, UTCHMMA | + + +| moe_backend | TP | EP | AlltoAll | +|----------|----------|----------|----------| +| CUTLASS | yes | yes | yes | +| TRTLLM | yes | yes | no | +| TRITON | no | yes | no | + +For best performance, use the `TRITON` moe_backend on Hopper for both latency and throughput cases. Use `CUTLASS` for throughput cases and `TRTLLM` for latency cases on Blackwell. + +## Harmony Examples + +### Function Calling + +OpenAI MoE models support function calling. Here is an example based on [XGrammar](https://github.com/mlc-ai/xgrammar)'s structural tag. + +First, launch a server with XGrammar enabled: + +```bash +cat > ./extra_llm_api_options.yaml <<EOF +guided_decoding_backend: xgrammar +EOF + +trtllm-serve <model> \ + --backend pytorch \ + --extra_llm_api_options extra_llm_api_options.yaml +``` + +Run the [openai_chat_client_function_calling.py](./openai_chat_client_function_calling.py) script, which queries the LLM server in two steps: + +1. **First step:** + - The client provides function definitions and a user prompt to the LLM server + - Instead of answering the prompt directly, the LLM server responds with a selected function and corresponding arguments based on the user prompt + - XGrammar's structural tag ensures the arguments conform to the function definition + +2. **Second step:** + - The client calls the selected function with the arguments and retrieves the results + - The client provides the chat history and function call results to the LLM server + - The LLM server provides the response based on the function call results + +For example, you can query "What is the weather like in SF?" with the following command: + +```bash +python openai_chat_client_function_calling.py \ + --model <model> \ + --prompt "What is the weather like in SF?" +``` + +The output would look similar to: + +```txt +[USER PROMPT] What is the weather like in SF? +[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in SF?" They want the weather in SF. SF likely refers to San Francisco. We need to get the current weather. We can use get_current_weather function. We need to provide location string "San Francisco, CA". We can also ask for format? By default celsius. But maybe user expects Fahrenheit? They didn't specify. We can provide celsius or Fahrenheit. We can choose default celsius. But maybe better to provide Fahrenheit because US. But default is celsius. We can provide both? We can call function with format "fahrenheit" to be user-friendly. But the function default is celsius. We can override. Let's call get_current_weather with location "San Francisco, CA" and format "fahrenheit". Then we will get the weather. Then we will respond with friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>{ + "location": "San Francisco, CA", + "format": "fahrenheit" +}<|call|> +[FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA', 'format': 'fahrenheit'}) +[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in SF?" We have fetched the weather: sunny true, temperature 68 (F). We need to respond in a friendly tone. Provide a friendly answer: "It's sunny and 68°F in San Francisco." Possibly add a friendly comment. Also ask if they want more details.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! It’s a pleasant 68 °F in San Francisco right now, and the sun is shining. 🌞 + +Anything else you'd like to know about the weather or maybe some fun things to do in the city today?<|return|> +``` + +The function call works successfully: +- In `[RESPONSE 1]`, the LLM selects the correct function `get_current_weather` and provides the appropriate arguments. +- In `[FUNCTION CALL]`, the client parses the LLM response and executes the function call. +- In `[RESPONSE 2]`, the LLM integrates the function call results into its final answer. + +Let's try another query "What is the weather like in NY and SF?" with the following command: + +```bash +python openai_chat_client_function_calling.py \ + --model <model> \ + --prompt "What is the weather like in NY and SF?" +``` + +The output would look like: + +```txt +[USER PROMPT] What is the weather like in NY and SF? +[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in NY and SF?" They want the weather in New York and San Francisco. We need to provide the weather. We can use the function get_multiple_weathers. We need to provide the list of city and state strings. For New York, we can use "New York, NY". For San Francisco, "San Francisco, CA". We can call get_multiple_weathers with those two locations. We should specify format? The default is celsius. But maybe the user might want Fahrenheit? They didn't specify. We can just use default celsius. But maybe we can provide both? But the function only returns one format. We can just use default celsius. But we can also ask the user? But the user asked "What is the weather like in NY and SF?" We can just provide the weather. We can call the function. Then we will get the weather data. Then we can respond with a friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>{"locations":["New York, NY","San Francisco, CA"]}<|call|> +[FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA']}) +[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in NY and SF?" We called get_multiple_weathers with locations ["New York, NY","San Francisco, CA"]. The function returned: [{"sunny": true, "temperature": 20}, {"sunny": true, "temperature": 20}]. That seems to be a list of two objects, each with sunny: true, temperature: 20. But we need to interpret the function output. The function get_multiple_weathers presumably returns a list of weather data for each location. But the returned data is ambiguous: we don't know which corresponds to which location. But we can assume the order matches the input order: first is New York, second is San Francisco. The temperature is 20 degrees Celsius? The function didn't specify units, but default is celsius. So 20°C. And sunny: true. So both are sunny and 20°C. We should respond in a friendly tone, summarizing the weather for both cities. We can mention that it's sunny and 20°C in both New York and San Francisco. We can also mention that it's a nice day. We can ask if they want more details. We should not mention the function call. We should just provide the answer.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! Here’s the scoop: + +- **New York, NY**: It’s sunny and a comfortable 20 °C (68 °F). +- **San Francisco, CA**: Also sunny with a pleasant 20 °C (68 °F). + +Looks like both coasts are enjoying a bright, mild day. Let me know if you’d like a forecast for later or any other details!<|return|> +``` + +Once again, the function call works successfully, this time using a different function: `get_multiple_weathers`. + +## Using OpenAI Triton Kernels for MoE + +OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels; enable them with the steps below: + +1. **Build and install Triton** (tested with the commit below): + +```bash +git clone https://github.com/triton-lang/triton.git +cd triton +# Specific commit verified with TensorRT-LLM +git checkout f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f +pip install -r python/requirements.txt # build-time dependencies +pip install wheel build +python3 setup.py bdist_wheel +pip install ./dist/*.whl +``` + +2. **Expose the Triton kernels to TensorRT-LLM** + The kernels are not packaged in the wheel, so set the environment variable `TRITON_ROOT` to your Triton clone: + +```bash +export TRITON_ROOT=/local/user/triton +# TensorRT-LLM expects the kernels at: +# $TRITON_ROOT/python/triton_kernels +``` + +3. **Select Triton as the MoE backend** + +• **trtllm-serve** (or other similar commands) — add this snippet to the YAML file passed via `--extra_llm_api_options`: + +```yaml +moe_config: + backend: TRITON +``` + +• **Example scripts** (e.g. `examples/llm-api/quickstart_advanced.py`) — pass the CLI flag: + +```bash +--moe_backend TRITON +``` diff --git a/examples/models/core/gpt_oss/openai_chat_client_function_calling.py b/examples/models/core/gpt_oss/openai_chat_client_function_calling.py new file mode 100644 index 0000000000..6450688ab8 --- /dev/null +++ b/examples/models/core/gpt_oss/openai_chat_client_function_calling.py @@ -0,0 +1,191 @@ +import argparse +import json +import re + +from openai import OpenAI + +system_prompt = """You are ChatGPT, a large language model trained by OpenAI. +Knowledge cutoff: 2024-06 +Current date: 2025-06-28 + +Reasoning: high + +# Valid channels: analysis, commentary, final. Channel must be included for every message. +Calls to these tools must go to the commentary channel: 'functions'.""" + +developer_prompt = """# Instructions + +Use a friendly tone. + +# Tools + +## functions + +namespace functions { + +// Gets the location of the user. +type get_location = () => any; + +// Gets the current weather in the provided location. +type get_current_weather = (_: { +// The city and state, e.g. San Francisco, CA +location: string, +format?: "celsius" | "fahrenheit", // default: celsius +}) => any; + +// Gets the current weather in the provided list of locations. +type get_multiple_weathers = (_: { +// List of city and state, e.g. ["San Francisco, CA", "New York, NY"] +locations: string[], +format?: "celsius" | "fahrenheit", // default: celsius +}) => any; + +} // namespace functions""" + +schema_get_current_weather = { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "description": "default: celsius", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], +} + +schema_get_multiple_weathers = { + "type": "object", + "properties": { + "locations": { + "type": + "array", + "items": { + "type": "string" + }, + "description": + 'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]', + }, + "format": { + "type": "string", + "description": "default: celsius", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["locations"], +} + + +def get_current_weather(location: str, format: str = "celsius") -> dict: + return {"sunny": True, "temperature": 20 if format == "celsius" else 68} + + +def get_multiple_weathers(locations: list[str], + format: str = "celsius") -> list[dict]: + return [get_current_weather(location, format) for location in locations] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--prompt", + type=str, + default="What is the weather like in SF?") + args = parser.parse_args() + + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="tensorrt_llm", + ) + + messages = [ + { + "role": "system", + "content": system_prompt, + }, + { + "role": "developer", + "content": developer_prompt, + }, + { + "role": "user", + "content": args.prompt, + }, + ] + + print(f"[USER PROMPT] {args.prompt}") + chat_completion = client.chat.completions.create( + model=args.model, + messages=messages, + max_completion_tokens=500, + response_format={ + "type": + "structural_tag", + "structures": [{ + "begin": + "<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>", + "schema": schema_get_current_weather, + "end": "<|call|>", + }, { + "begin": + "<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>", + "schema": schema_get_multiple_weathers, + "end": "<|call|>", + }], + "triggers": ["<|channel|>commentary to="], + }, + stop=["<|call|>"], + extra_body={ + "skip_special_tokens": False, + "include_stop_str_in_output": True, + }, + ) + + response_text = chat_completion.choices[0].message.content + print(f"[RESPONSE 1] {response_text}") + + for regex, tool in [ + (r"(<\|channel\|>commentary to=get_current_weather <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)", + get_current_weather), + (r"(<\|channel\|>commentary to=get_multiple_weathers <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)", + get_multiple_weathers) + ]: + match = re.search(regex, response_text) + if match is not None: + break + else: + print("Failed to call functions, exiting...") + return + + kwargs = json.loads(match.group(2)) + print(f"[FUNCTION CALL] {tool.__name__}(**{kwargs})") + answer = tool(**kwargs) + + messages.extend([{ + "role": "assistant", + "content": match.group(0), + }, { + "role": f"{tool.__name__} to=assistant", + "content": json.dumps(answer), + }]) + + chat_completion = client.chat.completions.create( + model=args.model, + messages=messages, + max_completion_tokens=500, + extra_body={ + "skip_special_tokens": False, + "include_stop_str_in_output": True, + }, + ) + + response_text = chat_completion.choices[0].message.content + print(f"[RESPONSE 2] {response_text}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/core/kimi_k2/README.md b/examples/models/core/kimi_k2/README.md new file mode 100644 index 0000000000..1dd3e353c5 --- /dev/null +++ b/examples/models/core/kimi_k2/README.md @@ -0,0 +1,127 @@ +# K2 (Kimi-K2-Instruct) + +## Overview + +Kimi K2 is Moonshot AI's Mixture-of-Experts model with 32 billion activated parameters and 1 trillion total parameters. It achieves state-of-the-art performance in frontier knowledge, math, and coding among non-thinking models. Notably, K2 also excels in agentic capabilities, demonstrating outstanding performance across complex, multi-step tasks. + +## Prerequisites for Tool Calling in Kimi-K2 + +K2 model supports tool calling functionality. The official guide can be found at: [tool_call_guidance](https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md) + +As described in the official guide, a tool calling process in Kimi-K2 includes: +1. Passing function descriptions to Kimi-K2. +2. Kimi-K2 decides to make a function call and returns the necessary information for the function call to the user. +3. The user performs the function call, collects the call results, and passes the function call results to Kimi-K2 +4. Kimi-K2 continues to generate content based on the function call results until the model believes it has obtained sufficient information to respond to the user + +Tools are the primary way to define callable functions for K2. Each tool requires: +- A unique name +- A clear description +- A JSON schema defining the expected parameters + +A possible example of tool description(you may refer to [Using tools](https://huggingface.co/docs/hugs/guides/function-calling) for more information) is as follows: +```python +# Collect the tool descriptions in tools +tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information. Call this tool when the user needs to get weather information", + "parameters": { + "type": "object", + "required": ["location"], + "properties": { + "location": { + "type": "string", + "description": "location name", + } + } + } + } +}] +``` + +Kimi currently supports two main approaches for tool calling: +1. *Use openai.OpenAI to send messages to Kimi-K2 together with tool descriptions.* +In this mode, the descriptions of the tools are passed as an argument to `client.chat.completions.create`, and the tool-call details can be read directly from the corresponding fields in the response. +2. *Manually parse the tool-call requests from the outputs generated by Kimi-K2.* +The tool call requests generated by Kimi-K2 are wrapped by <|tool_calls_section_begin|> and <|tool_calls_section_end|>, with each tool call wrapped by <|tool_call_begin|> and <|tool_call_end|>. The tool ID and arguments are separated by <|tool_call_argument_begin|>. The format of the tool ID is functions.{func_name}:{idx}, from which we can parse the function name. + +**Note that TensorRT-LLM does not support the first approach for now. If you deploy K2 with TensorRT-LLM, you need to manually parse the tool-call requests from the outputs.** + +The next section is an example that deploys the K2 model using TensorRT-LLM and then manually parses the tool-call results. + +## Example: Manually Parsing Tool-Call Requests from Kimi-K2 Outputs + +First, launch a server using trtllm-serve: + +```bash +cat > ./extra_llm_api_options.yaml <<EOF +# define your extra parameters here +cuda_graph_config: + batch_sizes: + - 1 + - 4 +enable_attention_dp: False +EOF + +trtllm-serve \ + --model /path_to_model/Kimi-K2-Instruct/ \ + --backend pytorch \ + --tp_size 8 \ + --ep_size 8 \ + --extra_llm_api_options extra_llm_api_options.yaml +``` + +Run the script [kimi_k2_tool_calling_example.py](./kimi_k2_tool_calling_example.py), which performs the following steps: + +1. The client provides tool definitions and a user prompt to the LLM server. +2. Instead of answering the prompt directly, the LLM server responds with a selected tool and corresponding arguments based on the user prompt. +3. The client calls the selected tool with the arguments and retrieves the results. + +For example, you can query "What's the weather like in shanghai today?" with the following command: + +```bash +python kimi_k2_tool_calling_example.py \ + --model "moonshotai/Kimi-K2-Instruct" \ + --prompt "What's the weather like in shanghai today?" +``` + +The output would look similar to: + +```txt +[The original output from Kimi-K2]: <|tool_calls_section_begin|> +<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "shanghai"}<|tool_call_end|> +<|tool_calls_section_end|>user + +[The tool-call requests parsed from the output]: [{'id': 'functions.get_weather:0', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{"location": "shanghai"}'}}] + +[Tool call result]: tool_name=get_weather, tool_result=Cloudy +``` + +The tool call works successfully: +- In `[The original output from Kimi-K2]`, the LLM selects the correct tool `get_weather` and provides the appropriate arguments. +- In `[The tool-call requests parsed from the output]`, the client parses the LLM response. +- In `[Tool call result]`, the client executes the tool function and get the result. + +Let's try another query, "What's the weather like in beijing today?", using a predefined system prompt to specify the output format as shown below. + +```bash +python kimi_k2_tool_calling_example.py \ + --model "moonshotai/Kimi-K2-Instruct" \ + --prompt "What's the weather like in beijing today?" + --specify_output_format +``` + +The output would look like: + +```txt +[The original output from Kimi-K2]: [get_weather(location='beijing')]user + +[The tool-call requests parsed from the output]: [{'type': 'function', 'function': {'name': 'get_weather', 'arguments': {'location': 'beijing'}}}] + +[Tool call result]: tool_name=get_weather, tool_result=Sunny +``` +Once again, the tool call works successfully and the original output from Kimi-K2 is formatted. + +**Note that, without guided decoding or other deterministic tool adapters, K2 sometimes deviates from the specified output format. Because TensorRT-LLM does not support K2 with guided decoding for now, you have to parse the tool calls carefully from the raw model output to ensure they meet the required format.** diff --git a/examples/models/core/kimi_k2/kimi_k2_tool_calling_example.py b/examples/models/core/kimi_k2/kimi_k2_tool_calling_example.py new file mode 100644 index 0000000000..2850547704 --- /dev/null +++ b/examples/models/core/kimi_k2/kimi_k2_tool_calling_example.py @@ -0,0 +1,201 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import ast +import json +import re + +from openai import OpenAI + +SPECIFY_OUTPUT_FORMAT_PROMPT = """You are an AI assistant with the role name "assistant." \ +Based on the provided API specifications and conversation history from steps 1 to t, \ +generate the API requests that the assistant should call in step t+1. \ +The API requests should be output in the format [api_name(key1='value1', key2='value2', ...)], \ +replacing api_name with the actual API name, key1, key2, etc., with the actual parameter names, \ +and value1, value2, etc., with the actual parameter values. The output should start with a square bracket "[" and end with a square bracket "]". +If there are multiple API requests, separate them with commas, for example: \ +[api_name(key1='value1', key2='value2', ...), api_name(key1='value1', key2='value2', ...), ...]. \ +Do not include any other explanations, prompts, or API call results in the output. +If the API parameter description does not specify otherwise, the parameter is optional \ +(parameters mentioned in the user input need to be included in the output; if not mentioned, they do not need to be included). +If the API parameter description does not specify the required format for the value, use the user's original text for the parameter value. \ +If the API requires no parameters, output the API request directly in the format [api_name()], and do not invent any nonexistent parameter names. + +API Specifications: +{tools}""" + +NOT_SPECIFY_OUTPUT_FORMAT_PROMPT = """Important: Only give the tool call requests, \ +do not include any other explanations, prompts, or API call results in the output. +The tool call requests generated by you are wrapped by \ +<|tool_calls_section_begin|> and <|tool_calls_section_end|>, with each tool call wrapped by <|tool_call_begin|> and <|tool_call_end|>. \ +The tool ID and arguments are separated by <|tool_call_argument_begin|>. The format of the tool ID is functions.func_name:idx, \ +from which we can parse the function name. + +API Specifications: +{tools}""" + + +def get_weather(location: str): + if location.lower() == "beijing": + return "Sunny" + elif location.lower() == "shanghai": + return "Cloudy" + else: + return "Rainy" + + +# Tool name->object mapping for easy calling later +tool_map = {"get_weather": get_weather} + + +# ref: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md +def extract_tool_call_info(tool_call_rsp: str): + if '<|tool_calls_section_begin|>' not in tool_call_rsp: + # No tool calls + return [] + pattern = r"<\|tool_calls_section_begin\|>(.*?)<\|tool_calls_section_end\|>" + + tool_calls_sections = re.findall(pattern, tool_call_rsp, re.DOTALL) + + # Extract multiple tool calls + func_call_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*<\|tool_call_end\|>" + tool_calls = [] + for match in re.findall(func_call_pattern, tool_calls_sections[0], + re.DOTALL): + function_id, function_args = match + # function_id: functions.get_weather:0 + function_name = function_id.split('.')[1].split(':')[0] + tool_calls.append({ + "id": function_id, + "type": "function", + "function": { + "name": function_name, + "arguments": function_args + } + }) + return tool_calls + + +def parse_specified_format_tool_calls(text: str): + pattern = re.compile(r'(\w+)\s*\(([^)]*)\)') + tool_calls = [] + + for m in pattern.finditer(text): + api_name, kv_body = m.group(1), m.group(2) + + kv_pattern = re.compile(r'(\w+)\s*=\s*([^,]+)') + kwargs = {} + for k, v in kv_pattern.findall(kv_body): + try: + kwargs[k] = ast.literal_eval(v.strip()) + except Exception: + kwargs[k] = v.strip() + + tool_calls.append({ + "type": "function", + "function": { + "name": api_name, + "arguments": kwargs + } + }) + + return tool_calls + + +def get_tools(): + # Collect the tool descriptions in tools + return [{ + "type": "function", + "function": { + "name": "get_weather", + "description": + "Get weather information. Call this tool when the user needs to get weather information", + "parameters": { + "type": "object", + "required": ["location"], + "properties": { + "location": { + "type": "string", + "description": "Location name", + } + } + } + } + }] + + +def get_tool_call_requests(args, client): + model = args.model + tools = get_tools() + system_prompt = SPECIFY_OUTPUT_FORMAT_PROMPT if args.specify_output_format else NOT_SPECIFY_OUTPUT_FORMAT_PROMPT.format( + tools=tools) + messages = [{ + "role": "system", + "content": system_prompt + }, { + "role": "user", + "content": args.prompt + }] + + response = client.chat.completions.create(model=model, + messages=messages, + max_tokens=256, + temperature=0.0) + + output = response.choices[0].message.content + tool_calls = parse_specified_format_tool_calls( + output) if args.specify_output_format else extract_tool_call_info( + output) + print(f"[The original output from Kimi-K2]: {output}\n") + print(f"[The tool-call requests parsed from the output]: {tool_calls}\n") + return tool_calls, messages + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", + type=str, + default="moonshotai/Kimi-K2-Instruct") + parser.add_argument("--prompt", + type=str, + default="What's the weather like in Shanghai today?") + parser.add_argument("--specify_output_format", + action="store_true", + default=False) + + args = parser.parse_args() + + # start trt-llm server before running this script + client = OpenAI( + api_key="tensorrt_llm", + base_url="http://localhost:8000/v1", + ) + + tool_calls, messages = get_tool_call_requests(args, client) + + for tool_call in tool_calls: + tool_name = tool_call['function']['name'] + if args.specify_output_format: + tool_arguments = tool_call['function']['arguments'] + else: + tool_arguments = json.loads(tool_call['function']['arguments']) + tool_function = tool_map[tool_name] + tool_result = tool_function(**tool_arguments) + print( + f"[Tool call result]: tool_name={tool_name}, tool_result={tool_result}\n" + ) diff --git a/examples/models/core/llama/README.md b/examples/models/core/llama/README.md index bef4f60123..b888b287b0 100644 --- a/examples/models/core/llama/README.md +++ b/examples/models/core/llama/README.md @@ -676,7 +676,7 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_2gpu_fp8 \ The peak GPU memory consumption when doing FP8 quantizaton is more than 210GB (there is also some activation memory occupation when doing calibration). So you need a node with at least 4 H100(A100) to run the quantization command. After quantization, 2 GPUs are okay to for building and run. -Experimental: use FP8 GEMV to optimize performance in FP8 small-batch-size cases. +Note: use FP8 GEMV to optimize performance in FP8 small-batch-size cases. ```bash # Quantize HF LLaMA 7B into FP8 and export trtllm checkpoint @@ -694,7 +694,7 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp8 \ --gemm_plugin fp8 ``` -**Note**: FP8 gemm plugin is an experimental feature aimed to improve performance in small-batch-size cases(e.g. BS<=4). Although inputs with batch size larger than 4 can be correctly inferenced, the performance may decrease as batch size grows. +**Note**: FP8 gemv plugin uses CUDA cores to compute, by contrast to Tensor Core gemm kernel within cuBLAS. Over last year, as cuBLAS have improved their performance by a lot under small M case for Hopper(sm90), FP8 gemv kernel may or may not surpass cuBLAS, depending on specific gemm problem shape. Nonetheless, we still strongly recommend FP8 gemv kernel for Ada (sm89) as cuBLAS still falls behind gemv on it. ### Groupwise quantization (AWQ/GPTQ) One can enable AWQ/GPTQ INT4 weight only quantization with these options when building engine with `trtllm-build`: diff --git a/examples/models/core/mixtral/requirements.txt b/examples/models/core/mixtral/requirements.txt index 50164ee5d3..fee4da9cf6 100644 --- a/examples/models/core/mixtral/requirements.txt +++ b/examples/models/core/mixtral/requirements.txt @@ -1,4 +1,4 @@ -c ../../../constraints.txt tensorrt_llm>=0.0.0.dev0 -transformers==4.38.2 +transformers==4.54.0 accelerate==0.25.0 diff --git a/examples/models/core/multimodal/requirements-eclair.txt b/examples/models/core/multimodal/requirements-eclair.txt new file mode 100644 index 0000000000..281c8ae93a --- /dev/null +++ b/examples/models/core/multimodal/requirements-eclair.txt @@ -0,0 +1 @@ +timm diff --git a/examples/models/core/multimodal/requirements-qwen2vl.txt b/examples/models/core/multimodal/requirements-qwen2vl.txt index 50f14d1d80..c75f6c32e1 100644 --- a/examples/models/core/multimodal/requirements-qwen2vl.txt +++ b/examples/models/core/multimodal/requirements-qwen2vl.txt @@ -1,2 +1,3 @@ accelerate qwen-vl-utils==0.0.8 # 0.0.9 has bug https://github.com/QwenLM/Qwen2-VL/pull/673, rollback until a newer version is released +transformers==4.51.0 # nvbugs/5385987 diff --git a/examples/models/core/qwen/requirements.txt b/examples/models/core/qwen/requirements.txt index 397e53956d..64ada0fdb3 100644 --- a/examples/models/core/qwen/requirements.txt +++ b/examples/models/core/qwen/requirements.txt @@ -10,7 +10,7 @@ tiktoken einops # optional dependencies -gradio==4.36.0 +gradio==4.44.1 mdtex2html sse_starlette aiohttp_sse_client diff --git a/examples/quantization/quantize_mixed_precision_moe.py b/examples/quantization/quantize_mixed_precision_moe.py index b931f408dc..fb7f94b858 100644 --- a/examples/quantization/quantize_mixed_precision_moe.py +++ b/examples/quantization/quantize_mixed_precision_moe.py @@ -45,10 +45,16 @@ def load_and_preprocess_state_dict(modelopt_state_root, world_size=8): state_dict_list = [] # load amax from state dict for rank in range(world_size): - state_dict_list.append( - torch.load( - f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt", - map_location="cuda:0")) + amax_file = f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt" + if os.path.exists(amax_file): + state_dict_list.append(torch.load(amax_file, map_location="cuda:0")) + else: + print(f"WARNING: amax file not found: {amax_file}") + + if not state_dict_list: + print("ERROR: No amax files loaded!") + return {} + # calculate the max across all TP ranks merged_state_dict = state_dict_list[0] for rank in range(world_size): @@ -232,15 +238,18 @@ def main(args): continue new_safetensors.update({key: get_tensor(key)}) + # Process activation scales for all ranks + if os.path.isdir(args.act_scales): + # Extract activation scales + renamed_state_dict = load_and_preprocess_state_dict( + modelopt_state_root=args.act_scales, world_size=8) + scales = get_scales_from_amax(start_layer=start_layer, + end_layer=end_layer, + renamed_state_dict=renamed_state_dict) + new_safetensors.update(scales) + if args.rank == 0: - if os.path.isdir(args.act_scales): - # Extract activation scales - renamed_state_dict = load_and_preprocess_state_dict( - modelopt_state_root=args.act_scales, world_size=8) - get_scales_from_amax(start_layer=start_layer, - end_layer=end_layer, - renamed_state_dict=renamed_state_dict) - else: + if not os.path.isdir(args.act_scales): input_scales = safe_open(args.act_scales, "pt") for k in input_scales.keys(): new_safetensors.update({k: input_scales.get_tensor(k)}) @@ -259,7 +268,10 @@ def main(args): ] for name in names: shutil.copy(os.path.join(model_dir, name), output_dir) - shutil.copy(args.act_scales, output_dir) + if os.path.isdir(args.act_scales): + shutil.copytree(args.act_scales, output_dir, dirs_exist_ok=True) + else: + shutil.copy(args.act_scales, output_dir) # config.json del config['quantization_config'] diff --git a/examples/sample_weight_stripping/README.md b/examples/sample_weight_stripping/README.md index bd28a60b84..a005f0904b 100644 --- a/examples/sample_weight_stripping/README.md +++ b/examples/sample_weight_stripping/README.md @@ -12,7 +12,7 @@ * [Llama-7b FP16 + WoQ INT8](#llama-7b-fp16-woq-int8) * [Llama2-70b FP8 with TP=2](#llama2-70b-fp8-with-tp2) - [Engine Plan File Size Results](#engine-plan-file-size-results) -- [Experimental](#experimental) +- [Prototype](#prototype) * [Checkpoint Pruner](#checkpoint-pruner) * [Pruning a TensorRT-LLM Checkpoint](#pruning-a-tensorrt-llm-checkpoint) @@ -239,7 +239,7 @@ python3 ../summarize.py --engine_dir engines/llama2-70b-hf-fp8-tp2.refit \ |llama-7b FP16 + WoQ INT8 | 6.54GB | 28.69MB | |llama2-70b FP8 + TP=2 | 64.78GB | 60.61MB | -## Experimental +## Prototype ### Checkpoint Pruner The checkpoint pruner allows you to strip `Conv` and `Gemm` weights out of a TensorRT-LLM [checkpoint](https://nvidia.github.io/TensorRT-LLM/latest/architecture/checkpoint.html). Since these make up the vast majority of weights, the pruner will decrease the size of your checkpoint up to 99%. diff --git a/examples/serve/openai_completion_client_json_schema.py b/examples/serve/openai_completion_client_json_schema.py index 2f110270f5..56e5a351a0 100644 --- a/examples/serve/openai_completion_client_json_schema.py +++ b/examples/serve/openai_completion_client_json_schema.py @@ -1,5 +1,9 @@ ### :title OpenAI Completion Client with JSON Schema +# This example requires to specify `guided_decoding_backend` as +# `xgrammar` or `llguidance` in the extra_llm_api_options.yaml file. +import json + from openai import OpenAI client = OpenAI( @@ -18,7 +22,6 @@ response = client.chat.completions.create( "content": f"Give me the information of the biggest city of China in the JSON format.", }], - max_tokens=100, temperature=0, response_format={ "type": "json", @@ -39,4 +42,11 @@ response = client.chat.completions.create( } }, ) -print(response.choices[0].message.content) + +content = response.choices[0].message.content +try: + response_json = json.loads(content) + assert "name" in response_json and "population" in response_json + print(content) +except json.JSONDecodeError: + print("Failed to decode JSON response") diff --git a/examples/wide_ep/slurm_scripts/README.md b/examples/wide_ep/slurm_scripts/README.md index 3bd5e926b2..d45f66bba8 100644 --- a/examples/wide_ep/slurm_scripts/README.md +++ b/examples/wide_ep/slurm_scripts/README.md @@ -17,7 +17,7 @@ Please note that: ### Core Scripts -Note that, core implementation of the slurm scripts are included in `examples/disaggregated/slurm`. +Note that, core implementation of the slurm scripts are included in `examples/disaggregated/slurm/benchmark`. 1. `submit.sh` - Main entry point for submitting benchmark jobs 2. `process_gen_iterlog.py` - Processes benchmark results and generates reports @@ -35,8 +35,8 @@ Before running the scripts, ensure you have: ### Running Benchmarks ```bash -# Refer to `examples/disaggregated/slurm/` -# Please find the `disaggr_torch.slurm` script in the `examples/disaggregated/slurm/` directory. +# Refer to `examples/disaggregated/slurm/benchmark/` +# Please find the `disaggr_torch.slurm` script in the `examples/disaggregated/slurm/benchmark/` directory. # Make sure that SLURM parameters are correctly set in `disaggr_torch.slurm` before executing this script. ./submit.sh ``` diff --git a/examples/wide_ep/slurm_scripts/submit.sh b/examples/wide_ep/slurm_scripts/submit.sh index f83fa53b55..f5b887f812 100644 --- a/examples/wide_ep/slurm_scripts/submit.sh +++ b/examples/wide_ep/slurm_scripts/submit.sh @@ -1,6 +1,6 @@ #!/bin/bash -echo "Please find the \`disaggr_torch.slurm\` script in the \`examples/disaggregated/slurm/\` directory." +echo "Please find the \`disaggr_torch.slurm\` script in the \`examples/disaggregated/slurm/benchmark/\` directory." partition=<partition> account=<account> @@ -9,9 +9,10 @@ container_image=<container_image> mounts=<mounts> # e.g. /mnt/data:/mnt/data workdir=<workdir> # Path to disaggr_torch.slurm model_dir=<model_dir> # Path to the model checkpoint +repo_dir=<repo_dir> # Path to the repo to install TensorRT-LLM, if this is empty, the pre-installed version will be used mtp_size=0 -ntasks_per_node=4 # 4 GPUs per GB200 node +ntasks_per_node=4 # 4 GPUs per GB200 node, 8 GPUs per B200 node isl=1024 osl=1024 @@ -22,13 +23,14 @@ streaming=true for b in 1 64 1024; do for eplb_num_slots in 0 256 288; do concurrency=$((b * 16)) - ctx_num=$(((concurrency + 5499)/5500)) - total_node_num=$((ctx_num + 4)) + ctx_node_num=$(((concurrency + 5499)/5500)) # $(((concurrency + 10999)/11000)) for B200 + ctx_num=${ctx_node_num} # $((ctx_node_num * 2)) for B200 + total_node_num=$((ctx_node_num + 4)) # $((ctx_node_num + 2)) for B200 ntasks=$((total_node_num * ntasks_per_node)) args=( ${ctx_num} 4 4 4480 true # Context servers arguments - 1 16 1024 1024 "0.7" # Generation servers arguments + 1 16 1024 1024 true "0.7" # Generation servers arguments $eplb_num_slots $mtp_size # Other arguments $concurrency # Benchmarking arguments $isl @@ -39,6 +41,7 @@ for b in 1 64 1024; do $mounts $workdir $model_dir + $repo_dir ) sbatch --nodes=${total_node_num} \ @@ -56,8 +59,9 @@ done # dep32 eplb288 for b in 512; do concurrency=$((b * 32)) - ctx_num=$(((concurrency + 5499)/5500)) - total_node_num=$((ctx_num + 8)) + ctx_node_num=$(((concurrency + 5499)/5500)) # $(((concurrency + 10999)/11000)) for B200 + ctx_num=${ctx_node_num} # $((ctx_node_num * 2)) for B200 + total_node_num=$((ctx_node_num + 8)) # $((ctx_node_num + 4)) for B200 ntasks=$((total_node_num * ntasks_per_node)) eplb_num_slots=288 @@ -74,6 +78,7 @@ for b in 512; do $mounts $workdir $model_dir + $repo_dir ) sbatch --nodes=${total_node_num} \ diff --git a/jenkins/BuildDockerImage.groovy b/jenkins/BuildDockerImage.groovy index 5aa61708f5..64e03de476 100644 --- a/jenkins/BuildDockerImage.groovy +++ b/jenkins/BuildDockerImage.groovy @@ -258,7 +258,7 @@ def buildImage(config, imageKeyToTag) // Step 2: Build the images stage ("Install packages") { sh "pwd && ls -alh" - sh "env" + sh "env | sort" sh "apk add make git" sh "git config --global --add safe.directory '*'" @@ -281,23 +281,31 @@ def buildImage(config, imageKeyToTag) try { def build_jobs = BUILD_JOBS // Fix the triton image pull timeout issue - def TRITON_IMAGE = sh(script: "cd ${LLM_ROOT} && grep 'ARG TRITON_IMAGE=' docker/Dockerfile.multi | grep -o '=.*' | tr -d '=\"'", returnStdout: true).trim() - def TRITON_BASE_TAG = sh(script: "cd ${LLM_ROOT} && grep 'ARG TRITON_BASE_TAG=' docker/Dockerfile.multi | grep -o '=.*' | tr -d '=\"'", returnStdout: true).trim() + def BASE_IMAGE = sh(script: "cd ${LLM_ROOT} && grep '^ARG BASE_IMAGE=' docker/Dockerfile.multi | grep -o '=.*' | tr -d '=\"'", returnStdout: true).trim() + def TRITON_IMAGE = sh(script: "cd ${LLM_ROOT} && grep '^ARG TRITON_IMAGE=' docker/Dockerfile.multi | grep -o '=.*' | tr -d '=\"'", returnStdout: true).trim() + def TRITON_BASE_TAG = sh(script: "cd ${LLM_ROOT} && grep '^ARG TRITON_BASE_TAG=' docker/Dockerfile.multi | grep -o '=.*' | tr -d '=\"'", returnStdout: true).trim() + + if (target == "rockylinux8") { + BASE_IMAGE = sh(script: "cd ${LLM_ROOT} && grep '^jenkins-rockylinux8_%: BASE_IMAGE =' docker/Makefile | grep -o '=.*' | tr -d '=\"'", returnStdout: true).trim() + } + + // Replace the base image and triton image with the internal mirror + BASE_IMAGE = BASE_IMAGE.replace("nvcr.io/", "urm.nvidia.com/docker/") + TRITON_IMAGE = TRITON_IMAGE.replace("nvcr.io/", "urm.nvidia.com/docker/") if (dependent) { stage ("make ${dependent.target}_${action} (${arch})") { - retry(3) { - sh "docker pull ${TRITON_IMAGE}:${TRITON_BASE_TAG}" - } - retry(3) { - sh """ - cd ${LLM_ROOT} && make -C docker ${dependent.target}_${action} \ - TORCH_INSTALL_TYPE=${torchInstallType} \ - IMAGE_WITH_TAG=${dependentImageWithTag} \ - STAGE=${dependent.dockerfileStage} \ - BUILD_WHEEL_OPTS='-j ${build_jobs}' ${args} - """ - } + def randomSleep = (Math.random() * 300 + 300).toInteger() + trtllm_utils.llmExecStepWithRetry(this, script: "docker pull ${TRITON_IMAGE}:${TRITON_BASE_TAG}", sleepInSecs: randomSleep, shortCommondRunTimeMax: 7200) + trtllm_utils.llmExecStepWithRetry(this, script: """ + cd ${LLM_ROOT} && make -C docker ${dependent.target}_${action} \ + BASE_IMAGE=${BASE_IMAGE} \ + TRITON_IMAGE=${TRITON_IMAGE} \ + TORCH_INSTALL_TYPE=${torchInstallType} \ + IMAGE_WITH_TAG=${dependentImageWithTag} \ + STAGE=${dependent.dockerfileStage} \ + BUILD_WHEEL_OPTS='-j ${build_jobs}' ${args} + """, sleepInSecs: randomSleep, numRetries: 3, shortCommondRunTimeMax: 7200) args += " DEVEL_IMAGE=${dependentImageWithTag}" if (target == "ngc-release") { imageKeyToTag["NGC Devel Image ${config.arch}"] = dependentImageWithTag @@ -315,18 +323,18 @@ def buildImage(config, imageKeyToTag) } } stage ("make ${target}_${action} (${arch})") { - retry(3) { - sh "docker pull ${TRITON_IMAGE}:${TRITON_BASE_TAG}" - } - retry(3) { - sh """ - cd ${LLM_ROOT} && make -C docker ${target}_${action} \ - TORCH_INSTALL_TYPE=${torchInstallType} \ - IMAGE_WITH_TAG=${imageWithTag} \ - STAGE=${dockerfileStage} \ - BUILD_WHEEL_OPTS='-j ${build_jobs}' ${args} - """ - } + sh "env | sort" + def randomSleep = (Math.random() * 300 + 300).toInteger() + trtllm_utils.llmExecStepWithRetry(this, script: "docker pull ${TRITON_IMAGE}:${TRITON_BASE_TAG}", sleepInSecs: randomSleep, shortCommondRunTimeMax: 7200) + trtllm_utils.llmExecStepWithRetry(this, script: """ + cd ${LLM_ROOT} && make -C docker ${target}_${action} \ + BASE_IMAGE=${BASE_IMAGE} \ + TRITON_IMAGE=${TRITON_IMAGE} \ + TORCH_INSTALL_TYPE=${torchInstallType} \ + IMAGE_WITH_TAG=${imageWithTag} \ + STAGE=${dockerfileStage} \ + BUILD_WHEEL_OPTS='-j ${build_jobs}' ${args} + """, sleepInSecs: randomSleep, numRetries: 3, shortCommondRunTimeMax: 7200) if (target == "ngc-release") { imageKeyToTag["NGC Release Image ${config.arch}"] = imageWithTag } @@ -336,6 +344,8 @@ def buildImage(config, imageKeyToTag) stage ("custom tag: ${customTag} (${arch})") { sh """ cd ${LLM_ROOT} && make -C docker ${target}_${action} \ + BASE_IMAGE=${BASE_IMAGE} \ + TRITON_IMAGE=${TRITON_IMAGE} \ TORCH_INSTALL_TYPE=${torchInstallType} \ IMAGE_WITH_TAG=${customImageWithTag} \ STAGE=${dockerfileStage} \ diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index cefb06508c..d00dd66d53 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -339,7 +339,9 @@ def mergeWaiveList(pipeline, globalVars) LLM_TOT_ROOT = "llm-tot" targetBranch = env.gitlabTargetBranch ? env.gitlabTargetBranch : globalVars[TARGET_BRANCH] echo "Target branch: ${targetBranch}" - trtllm_utils.checkoutSource(LLM_REPO, targetBranch, LLM_TOT_ROOT, true, true) + withCredentials([string(credentialsId: 'default-sync-llm-repo', variable: 'DEFAULT_SYNC_LLM_REPO')]) { + trtllm_utils.checkoutSource(DEFAULT_SYNC_LLM_REPO, targetBranch, LLM_TOT_ROOT, false, false) + } targetBranchTOTCommit = sh (script: "cd ${LLM_TOT_ROOT} && git rev-parse HEAD", returnStdout: true).trim() echo "Target branch TOT commit: ${targetBranchTOTCommit}" sh "cp ${LLM_TOT_ROOT}/tests/integration/test_lists/waives.txt ./waives_TOT_${targetBranchTOTCommit}.txt" @@ -384,30 +386,11 @@ def launchReleaseCheck(pipeline) -y""") sh "pip3 config set global.break-system-packages true" sh "git config --global --add safe.directory \"*\"" - // Step 1: cloning tekit source code + // Step 1: Clone TRT-LLM source codes trtllm_utils.checkoutSource(LLM_REPO, env.gitlabCommit, LLM_ROOT, true, true) sh "cd ${LLM_ROOT} && git config --unset-all core.hooksPath" - trtllm_utils.llmExecStepWithRetry(pipeline, script: "cd ${LLM_ROOT} && python3 -u scripts/release_check.py || (git restore . && false)") - // Step 2: build tools - withEnv(['GONOSUMDB=*.nvidia.com']) { - withCredentials([ - gitUsernamePassword( - credentialsId: 'svc_tensorrt_gitlab_read_api_token', - gitToolName: 'git-tool' - ), - string( - credentialsId: 'default-git-url', - variable: 'DEFAULT_GIT_URL' - ) - ]) { - sh "go install ${DEFAULT_GIT_URL}/TensorRT/Infrastructure/licensechecker/cmd/license_checker@v0.3.0" - } - } - // Step 3: Run license check - sh "cd ${LLM_ROOT}/cpp && /go/bin/license_checker -config ../jenkins/license_cpp.json include tensorrt_llm" - - // Step 4: Run guardwords scan + // Step 2: Run guardwords scan def isOfficialPostMergeJob = (env.JOB_NAME ==~ /.*PostMerge.*/) if (env.alternativeTRT || isOfficialPostMergeJob) { trtllm_utils.checkoutSource(SCAN_REPO, SCAN_COMMIT, SCAN_ROOT, true, true) @@ -432,6 +415,26 @@ def launchReleaseCheck(pipeline) echo "Guardwords Scan Results: https://urm.nvidia.com/artifactory/${UPLOAD_PATH}/guardwords-scan-results/scan.log" } } + + // Step 3: Run pre-commit checks + trtllm_utils.llmExecStepWithRetry(pipeline, script: "cd ${LLM_ROOT} && python3 -u scripts/release_check.py || (git restore . && false)") + + // Step 4: Run license check + withEnv(['GONOSUMDB=*.nvidia.com']) { + withCredentials([ + gitUsernamePassword( + credentialsId: 'svc_tensorrt_gitlab_read_api_token', + gitToolName: 'git-tool' + ), + string( + credentialsId: 'default-git-url', + variable: 'DEFAULT_GIT_URL' + ) + ]) { + sh "go install ${DEFAULT_GIT_URL}/TensorRT/Infrastructure/licensechecker/cmd/license_checker@v0.3.0" + } + } + sh "cd ${LLM_ROOT}/cpp && /go/bin/license_checker -config ../jenkins/license_cpp.json include tensorrt_llm" } def image = "urm.nvidia.com/docker/golang:1.22" @@ -588,6 +591,12 @@ def getMergeRequestChangedFileList(pipeline, globalVars) { } def getMergeRequestOneFileChanges(pipeline, globalVars, filePath) { + def isOfficialPostMergeJob = (env.JOB_NAME ==~ /.*PostMerge.*/) + if (env.alternativeTRT || isOfficialPostMergeJob) { + pipeline.echo("Force set changed file diff to empty string.") + return "" + } + def githubPrApiUrl = globalVars[GITHUB_PR_API_URL] def diff = "" diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 16b8029aa7..f89a26c98c 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -42,7 +42,7 @@ LLM_DOCKER_IMAGE_12_9 = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch- LLM_SBSA_DOCKER_IMAGE_12_9 = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508051130-6090" // DLFW torch image -DLFW_IMAGE = "nvcr.io/nvidia/pytorch:25.06-py3" +DLFW_IMAGE = "urm.nvidia.com/docker/nvidia/pytorch:25.06-py3" //Ubuntu base image UBUNTU_22_04_IMAGE = "urm.nvidia.com/docker/ubuntu:22.04" @@ -110,6 +110,8 @@ MODEL_CACHE_DIR="/scratch.trt_llm_data/llm-models" ENABLE_NGC_DEVEL_IMAGE_TEST = params.enableNgcDevelImageTest ?: false ENABLE_NGC_RELEASE_IMAGE_TEST = params.enableNgcReleaseImageTest ?: false +COMMON_SSH_OPTIONS = "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null" + def uploadResults(def pipeline, SlurmCluster cluster, String nodeName, String stageName){ withCredentials([usernamePassword(credentialsId: 'svc_tensorrt', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { def remote = [ @@ -124,7 +126,7 @@ def uploadResults(def pipeline, SlurmCluster cluster, String nodeName, String st pipeline.stage('Submit Test Results') { sh "mkdir -p ${stageName}" def resultsFilePath = "/home/svc_tensorrt/bloom/scripts/${nodeName}/results/results.xml" - def downloadResultCmd = "sshpass -p '${remote.passwd}' scp -r -p -oStrictHostKeyChecking=no ${remote.user}@${remote.host}:${resultsFilePath} ${stageName}/" + def downloadResultCmd = "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${remote.user}@${remote.host}:${resultsFilePath} ${stageName}/" def downloadSucceed = sh(script: downloadResultCmd, returnStatus: true) == 0 if (downloadSucceed) { sh "ls ${stageName}" @@ -250,7 +252,7 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p Utils.exec(pipeline, script: "chmod +x ${jenkinsSetupPath}", returnStdout: true) - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p -oStrictHostKeyChecking=no ${jenkinsSetupPath} ${remote.user}@${remote.host}:~/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh",) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${jenkinsSetupPath} ${remote.user}@${remote.host}:~/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh",) Utils.exec( pipeline, @@ -338,7 +340,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL stage('Prepare Testing') { // Create Job Workspace folder in Frontend Node - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' ssh -oStrictHostKeyChecking=no ${remote.user}@${remote.host} 'mkdir ${jobWorkspace}'",) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' ssh ${COMMON_SSH_OPTIONS} ${remote.user}@${remote.host} 'mkdir -p ${jobWorkspace}'",) // Download and Unzip Tar File trtllm_utils.llmExecStepWithRetry(pipeline, script: "cd ${llmPath} && wget -nv ${llmTarfile}") @@ -347,11 +349,11 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL // Upload slurm_run_sh to Frontend node def scriptRunLocalPath = "${llmSrcLocal}/jenkins/scripts/slurm_run.sh" Utils.exec(pipeline, script: "chmod +x ${scriptRunLocalPath}", returnStdout: true) - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p -oStrictHostKeyChecking=no ${scriptRunLocalPath} ${remote.user}@${remote.host}:${scriptRunNode}",) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptRunLocalPath} ${remote.user}@${remote.host}:${scriptRunNode}",) // Upload waives.txt to Frontend node def waivesListLocalPath = "${llmSrcLocal}/tests/integration/test_lists/waives.txt" - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p -oStrictHostKeyChecking=no ${waivesListLocalPath} ${remote.user}@${remote.host}:${waivesListPathNode}",) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${waivesListLocalPath} ${remote.user}@${remote.host}:${waivesListPathNode}",) // Generate Test List and Upload to Frontend Node def makoArgs = getMakoArgsFromStageName(stageName, true) @@ -360,7 +362,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL // if the line cannot be split by "=", just ignore that line. def makoOptsJson = transformMakoArgsToJson(["Mako options:"] + makoArgs) def testListPath = renderTestDB(testList, llmSrcLocal, stageName, makoOptsJson) - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p -oStrictHostKeyChecking=no ${testListPath} ${remote.user}@${remote.host}:${testListPathNode}",) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${testListPath} ${remote.user}@${remote.host}:${testListPathNode}",) // Generate Multi Node Job Launch Script def container = LLM_DOCKER_IMAGE.replace("urm.nvidia.com/", "urm.nvidia.com#") @@ -404,7 +406,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL """.stripIndent() pipeline.writeFile(file: scriptLaunchDestPath, text: scriptContent) Utils.exec(pipeline, script: "chmod +x ${scriptLaunchDestPath}", returnStdout: true) - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p -oStrictHostKeyChecking=no ${scriptLaunchDestPath} ${remote.user}@${remote.host}:${scriptLaunch}",) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptLaunchDestPath} ${remote.user}@${remote.host}:${scriptLaunch}",) } stage('Run Test') { def scriptLaunch = "${jobWorkspace}/slurm_launch.sh" @@ -1100,7 +1102,7 @@ def getSSHConnectionPorts(portConfigFile, stageName) usernamePassword(credentialsId: 'tensorrt_llm_infra_debug_vm_01_credentials', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD'), string(credentialsId: 'DEBUG_HOST_NAME', variable: 'HOST_NAME') ]) { - portUsage = sh(script: "ssh -v ${USERNAME}@${HOST_NAME} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null 'netstat -tuln'",returnStdout: true) + portUsage = sh(script: "ssh -v ${USERNAME}@${HOST_NAME} ${COMMON_SSH_OPTIONS} 'netstat -tuln'", returnStdout: true) } echo "Port Usage: ${portUsage}" @@ -1259,7 +1261,7 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO def llmRootConfig = "${LLM_ROOT}${config}" sh "mkdir ${llmRootConfig}" - def llmPath = sh (script: "realpath ${llmRootConfig}",returnStdout: true).trim() + def llmPath = sh (script: "realpath ${llmRootConfig}", returnStdout: true).trim() def llmSrc = "${llmPath}/TensorRT-LLM/src" echoNodeAndGpuInfo(pipeline, stageName) @@ -1376,9 +1378,9 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO usernamePassword(credentialsId: 'tensorrt_llm_infra_debug_vm_01_credentials', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD'), string(credentialsId: 'DEBUG_HOST_NAME', variable: 'HOST_NAME') ]) { - sh "sshpass -p ${PASSWORD} -v ssh ${USERNAME}@${HOST_NAME} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null 'cat >> ~/.ssh/authorized_keys' < ~/.ssh/id_rsa.pub" - sh "ssh -v ${USERNAME}@${HOST_NAME} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null 'echo \"\" > ~/.ssh/known_hosts && cat ~/.ssh/id_rsa.pub' >> ~/.ssh/authorized_keys" - sh "ssh -v ${USERNAME}@${HOST_NAME} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null 'cat ~/.ssh/ports_config.txt' >> ${portConfigFilePath}" + sh "sshpass -p ${PASSWORD} -v ssh ${USERNAME}@${HOST_NAME} ${COMMON_SSH_OPTIONS} 'cat >> ~/.ssh/authorized_keys' < ~/.ssh/id_rsa.pub" + sh "ssh -v ${USERNAME}@${HOST_NAME} ${COMMON_SSH_OPTIONS} 'echo \"\" > ~/.ssh/known_hosts && cat ~/.ssh/id_rsa.pub' >> ~/.ssh/authorized_keys" + sh "ssh -v ${USERNAME}@${HOST_NAME} ${COMMON_SSH_OPTIONS} 'cat ~/.ssh/ports_config.txt' >> ${portConfigFilePath}" def (int userPort, int monitorPort) = getSSHConnectionPorts(portConfigFilePath, stageName) if (userPort == 0) { @@ -1387,7 +1389,7 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO return } - sh "ssh -f -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -L 1111:127.0.0.1:${monitorPort} -R ${monitorPort}:127.0.0.1:1112 -NR ${userPort}:localhost:22 ${USERNAME}@${HOST_NAME}" + sh "ssh -f ${COMMON_SSH_OPTIONS} -L 1111:127.0.0.1:${monitorPort} -R ${monitorPort}:127.0.0.1:1112 -NR ${userPort}:localhost:22 ${USERNAME}@${HOST_NAME}" sh "autossh -fNR ${userPort}:localhost:22 ${USERNAME}@${HOST_NAME}" sh "ps aux | grep ssh" try { @@ -2095,6 +2097,11 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) checkPipStage = true } + if (cpu_arch == AARCH64_TRIPLE && values[5] != DLFW_IMAGE) { + checkPipStage = false + echo "Skip pip install sanity check due to https://nvbugs/5453827" + } + if (checkPipStage) { stage("Run LLMAPI tests") { pipInstallSanitySpec = createKubernetesPodConfig(values[5], gpu_type, k8s_arch) @@ -2128,18 +2135,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 install torch==2.7.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128") } - // Workaround for https://nvbugs/5433581 where deep_gemm installation fails on SBSA platform - if (cpu_arch == AARCH64_TRIPLE) { - echo "###### Workaround for https://nvbugs/5433581 Start ######" - def deepGemmLine = readFile("${LLM_ROOT}/requirements.txt").readLines().find { it.trim().startsWith('deep_gemm') } - if (deepGemmLine) { - trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 install '${deepGemmLine.trim()}' --extra-index-url https://download.pytorch.org/whl/cu128") - } - else { - echo "deep_gemm package not found in requirements.txt" - } - } - def libEnv = [] if (env.alternativeTRT) { stage("Replace TensorRT") { diff --git a/jenkins/controlCCache.groovy b/jenkins/controlCCache.groovy index 82fa7757ad..bc34d88e4d 100644 --- a/jenkins/controlCCache.groovy +++ b/jenkins/controlCCache.groovy @@ -1,7 +1,7 @@ import java.lang.InterruptedException -DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507251001-5678" +DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508130930-6501" def createKubernetesPodConfig(image, arch = "amd64") { diff --git a/pyproject.toml b/pyproject.toml index b0e25b6ea9..3ee28a7e16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ ############################### BUILD CONFIGURATION ############################################## #################################################################################################### [build-system] -requires = ["setuptools >= 64"] +requires = ["setuptools >= 64", "pip >= 24"] build-backend = "setuptools.build_meta" #################################################################################################### @@ -68,7 +68,7 @@ ignore_patterns = [ [tool.codespell] skip = ".git,3rdparty,tests/integration/test_input_files**,**.jsonl,**.json" exclude-file = "examples/models/core/whisper/tokenizer.py" -ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw" +ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw,dOut" [tool.autoflake] in-place = true diff --git a/requirements.txt b/requirements.txt index 252a27987d..9e7ce380e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ diffusers>=0.27.0 lark mpi4py numpy<2 -onnx>=1.12.0 +onnx>=1.18.0 onnx_graphsurgeon>=0.5.2 openai polygraphy @@ -28,7 +28,9 @@ torchvision nvidia-modelopt[torch]~=0.33.0 nvidia-nccl-cu13 nvidia-cuda-nvrtc-cu13 -transformers==4.53.1 +transformers==4.55.0 +prometheus_client +prometheus_fastapi_instrumentator pydantic>=2.9.1 pydantic-settings[yaml] omegaconf @@ -52,6 +54,8 @@ einops flashinfer-python @ git+https://github.com/VALLIS-NERIA/flashinfer.git@29e57482ad6410444aa86cf92c669dbbb506e978 opencv-python-headless xgrammar==0.1.21 +llguidance==0.7.29 +jsonschema backoff nvtx matplotlib # FIXME: this is added to make nvtx happy @@ -59,7 +63,7 @@ meson ninja etcd3 blake3 -llguidance==0.7.29 soundfile triton==3.3.1; platform_machine == "x86_64" +tiktoken blobfile diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index 52abdbcb84..e40543c78f 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -16,8 +16,10 @@ import os import platform +import re import sys import sysconfig +import tempfile import warnings from argparse import ArgumentParser from contextlib import contextmanager @@ -27,7 +29,7 @@ from pathlib import Path from shutil import copy, copytree, rmtree from subprocess import DEVNULL, CalledProcessError, check_output, run from textwrap import dedent -from typing import List +from typing import Sequence try: from packaging.requirements import Requirement @@ -68,13 +70,13 @@ def get_build_dir(build_dir, build_type): def clear_folder(folder_path): for item in os.listdir(folder_path): item_path = os.path.join(folder_path, item) - if os.path.isdir(item_path) and not os.path.islink(item_path): - rmtree(item_path) - else: - try: + try: + if os.path.isdir(item_path) and not os.path.islink(item_path): + rmtree(item_path) + else: os.remove(item_path) - except (OSError, IOError) as e: - print(f"Failed to remove {item_path}: {e}", file=sys.stderr) + except (OSError, IOError) as e: + print(f"Failed to remove {item_path}: {e}", file=sys.stderr) def sysconfig_scheme(override_vars=None): @@ -120,7 +122,8 @@ def create_venv(project_dir: Path): return venv_prefix -def setup_venv(project_dir: Path, requirements_file: Path, no_venv: bool): +def setup_venv(project_dir: Path, requirements_file: Path, + no_venv: bool) -> tuple[Path, Path]: """Creates/updates a venv and installs requirements. Args: @@ -279,6 +282,139 @@ def generate_fmha_cu(project_dir, venv_python): os.chdir(project_dir) +def create_cuda_stub_links(cuda_stub_dir: str, missing_libs: list[str]) -> str: + """ + Creates symbolic links for CUDA stub libraries in a temporary directory. + + Args: + cuda_stub_dir (str): Path to the directory containing CUDA stubs. + missing_libs: Versioned names of the missing libraries. + + Returns: + str: Path to the temporary directory where links were created. + """ + cuda_stub_path = Path(cuda_stub_dir) + if not cuda_stub_path.exists(): + raise RuntimeError( + f"CUDA stub directory '{cuda_stub_dir}' does not exist.") + + # Create a temporary directory for the symbolic links + temp_dir = tempfile.mkdtemp(prefix="cuda_stub_links_") + temp_dir_path = Path(temp_dir) + + version_pattern = r'\.\d+' + for missing_lib in filter(lambda x: re.search(version_pattern, x), + missing_libs): + # Define `so` as the first part of `missing_lib` with trailing '.' and digits removed + so = cuda_stub_path / re.sub(version_pattern, '', missing_lib) + so_versioned = temp_dir_path / missing_lib + + # Check if the library exists in the original directory + if so.exists(): + try: + # Create the symbolic link in the temporary directory + so_versioned.symlink_to(so) + except OSError as e: + # Clean up the temporary directory on error + rmtree(temp_dir) + raise RuntimeError( + f"Failed to create symbolic link for '{missing_lib}' in temporary directory '{temp_dir}': {e}" + ) + else: + warnings.warn( + f"Warning: Source library '{so}' does not exist and was skipped." + ) + + # Return the path to the temporary directory where the links were created + return str(temp_dir_path) + + +def check_missing_libs(so_prefix: str) -> list[str]: + result = build_run(f"ldd {so_prefix}.cpython*.so", + capture_output=True, + text=True) + missing = [] + for line in result.stdout.splitlines(): + if "not found" in line: + lib_name = line.split()[ + 0] # Extract the library name before "=> not found" + if lib_name not in missing: + missing.append(lib_name) + return missing + + +def generate_python_stubs_linux(binding_type: str, venv_python: Path, + deep_ep: bool): + is_nanobind = binding_type == "nanobind" + if is_nanobind: + build_run(f"\"{venv_python}\" -m pip install nanobind") + build_run(f"\"{venv_python}\" -m pip install pybind11-stubgen") + + env_stub_gen = os.environ.copy() + cuda_home_dir = env_stub_gen.get("CUDA_HOME") or env_stub_gen.get( + "CUDA_PATH") or "/usr/local/cuda" + missing_libs = check_missing_libs("bindings") + cuda_stub_dir = f"{cuda_home_dir}/lib64/stubs" + + if missing_libs and Path(cuda_stub_dir).exists(): + # Create symbolic links for the CUDA stubs + link_dir = create_cuda_stub_links(cuda_stub_dir, missing_libs) + ld_library_path = env_stub_gen.get("LD_LIBRARY_PATH") + env_stub_gen["LD_LIBRARY_PATH"] = ":".join( + filter(None, [link_dir, cuda_stub_dir, ld_library_path])) + else: + link_dir = None + + try: + if is_nanobind: + build_run(f"\"{venv_python}\" -m nanobind.stubgen -m bindings -O .", + env=env_stub_gen) + else: + build_run( + f"\"{venv_python}\" -m pybind11_stubgen -o . bindings --exit-code", + env=env_stub_gen) + build_run( + f"\"{venv_python}\" -m pybind11_stubgen -o . deep_gemm_cpp_tllm --exit-code", + env=env_stub_gen) + if deep_ep: + build_run( + f"\"{venv_python}\" -m pybind11_stubgen -o . deep_ep_cpp_tllm --exit-code", + env=env_stub_gen) + finally: + if link_dir: + rmtree(link_dir) + + +def generate_python_stubs_windows(binding_type: str, venv_python: Path, + pkg_dir: Path, lib_dir: Path): + if binding_type == "nanobind": + print("Windows not yet supported for nanobind stubs") + exit(1) + else: + build_run(f"\"{venv_python}\" -m pip install pybind11-stubgen") + stubgen = "stubgen.py" + stubgen_contents = """ + # Loading torch, trt before bindings is required to avoid import errors on windows. + # isort: off + import torch + import tensorrt as trt + # isort: on + import os + import platform + + from pybind11_stubgen import main + + if __name__ == "__main__": + # Load dlls from `libs` directory before launching bindings. + if platform.system() == "Windows": + os.add_dll_directory(r\"{lib_dir}\") + main() + """.format(lib_dir=lib_dir) + (pkg_dir / stubgen).write_text(dedent(stubgen_contents)) + build_run(f"\"{venv_python}\" {stubgen} -o . bindings") + (pkg_dir / stubgen).unlink() + + def main(*, build_type: str = "Release", generator: str = "", @@ -286,7 +422,7 @@ def main(*, dist_dir: Path = None, cuda_architectures: str = None, job_count: int = None, - extra_cmake_vars: List[str] = list(), + extra_cmake_vars: Sequence[str] = tuple(), extra_make_targets: str = "", trt_root: str = '/usr/local/tensorrt', nccl_root: str = None, @@ -361,7 +497,7 @@ def main(*, if on_windows: # Windows does not support multi-device currently. - extra_cmake_vars.extend(["ENABLE_MULTI_DEVICE=0"]) + extra_cmake_vars = list(extra_cmake_vars) + ["ENABLE_MULTI_DEVICE=0"] # The Ninja CMake generator is used for our Windows build # (Easier than MSBuild to make compatible with our Docker image) @@ -703,81 +839,14 @@ def main(*, dirs_exist_ok=True) if not skip_stubs: - with working_directory(project_dir): - if binding_type == "nanobind": - build_run(f"\"{venv_python}\" -m pip install nanobind") - else: - build_run( - f"\"{venv_python}\" -m pip install pybind11-stubgen") with working_directory(pkg_dir): if on_windows: - if binding_type == "nanobind": - print("Windows not yet supported for nanobind stubs") - exit(1) - else: - stubgen = "stubgen.py" - stubgen_contents = """ - # Loading torch, trt before bindings is required to avoid import errors on windows. - # isort: off - import torch - import tensorrt as trt - # isort: on - import os - import platform - - from pybind11_stubgen import main - - if __name__ == "__main__": - # Load dlls from `libs` directory before launching bindings. - if platform.system() == "Windows": - os.add_dll_directory(r\"{lib_dir}\") - main() - """.format(lib_dir=lib_dir) - (pkg_dir / stubgen).write_text(dedent(stubgen_contents)) - build_run(f"\"{venv_python}\" {stubgen} -o . bindings") - (pkg_dir / stubgen).unlink() - else: - env_ld = os.environ.copy() - - new_library_path = "/usr/local/cuda/compat:/usr/local/cuda/compat/lib:/usr/local/cuda/compat/lib.real" - if 'LD_LIBRARY_PATH' in env_ld: - new_library_path += f":{env_ld['LD_LIBRARY_PATH']}" - - result = build_run("find /usr -name *libnvidia-ml.so*", - capture_output=True, - text=True) - assert result.returncode == 0, f"Failed to run find *libnvidia-ml.so*: {result.stderr}" - - # Build containers only contain stub version of libnvidia-ml.so and not the real version. - # If real version not in system, we need to create symbolic link to stub version to prevent import errors. - if "libnvidia-ml.so.1" not in result.stdout: - if "libnvidia-ml.so" in result.stdout: - line = result.stdout.splitlines()[0] - path = os.path.dirname(line) - new_library_path += f":{path}" - build_run(f"ln -s {line} {path}/libnvidia-ml.so.1") - else: - print( - f"Failed to find libnvidia-ml.so: {result.stderr}", - file=sys.stderr) - exit(1) - - env_ld["LD_LIBRARY_PATH"] = new_library_path - if binding_type == "nanobind": - build_run( - f"\"{venv_python}\" -m nanobind.stubgen -m bindings -O .", - env=env_ld) - else: - build_run( - f"\"{venv_python}\" -m pybind11_stubgen -o . bindings --exit-code", - env=env_ld) - if deep_ep_cuda_architectures: - build_run( - f"\"{venv_python}\" -m pybind11_stubgen -o . deep_ep_cpp_tllm --exit-code", - env=env_ld) - build_run( - f"\"{venv_python}\" -m pybind11_stubgen -o . deep_gemm_cpp_tllm --exit-code", - env=env_ld) + generate_python_stubs_windows(binding_type, venv_python, + pkg_dir, lib_dir) + else: # on linux + generate_python_stubs_linux( + binding_type, venv_python, + bool(deep_ep_cuda_architectures)) if not skip_building_wheel: if dist_dir is None: diff --git a/setup.py b/setup.py index d3293c4bee..b16d9ce470 100644 --- a/setup.py +++ b/setup.py @@ -120,14 +120,14 @@ package_data += [ ] -def download_precompiled(workspace: str) -> str: +def download_precompiled(workspace: str, version: str) -> str: import glob import subprocess from setuptools.errors import SetupError cmd = [ - "pip", "download", f"tensorrt_llm={get_version()}", + "python3", "-m", "pip", "download", f"tensorrt_llm=={version}", f"--dest={workspace}", "--no-deps", "--extra-index-url=https://pypi.nvidia.com" ] @@ -201,17 +201,18 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str], wheel.extract(file) -use_precompiled: bool = os.getenv("TRTLLM_USE_PRECOMPILED") == "1" -precompiled_location: str = os.getenv("TRTLLM_PRECOMPILED_LOCATION") - -if precompiled_location: - use_precompiled = True +precompiled: str | None = os.getenv("TRTLLM_USE_PRECOMPILED") +precompiled_location: str | None = os.getenv("TRTLLM_PRECOMPILED_LOCATION") +use_precompiled: bool = (precompiled is not None + and precompiled != "0") or (precompiled_location + is not None) if use_precompiled: from tempfile import TemporaryDirectory with TemporaryDirectory() as tempdir: if not precompiled_location: - precompiled_location = download_precompiled(tempdir) + version = precompiled if precompiled != "1" else get_version() + precompiled_location = download_precompiled(tempdir, version) extract_from_precompiled(precompiled_location, package_data, tempdir) sanity_check() diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index f54026a8cb..5cffa98546 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -33,6 +33,7 @@ import sys # otherwise `MemoryError: std::bad_alloc` pattern error will be raised. import xgrammar # noqa +import tensorrt_llm._torch.models as torch_models import tensorrt_llm.functional as functional import tensorrt_llm.math_utils as math_utils import tensorrt_llm.models as models @@ -82,6 +83,7 @@ __all__ = [ 'default_trtnet', 'precision', 'net_guard', + 'torch_models', 'Network', 'Mapping', 'MnnvlMemory', diff --git a/tensorrt_llm/_mnnvl_utils.py b/tensorrt_llm/_mnnvl_utils.py index 6a31ee9f10..39f6deac4c 100644 --- a/tensorrt_llm/_mnnvl_utils.py +++ b/tensorrt_llm/_mnnvl_utils.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import ctypes +import os import platform import sys from dataclasses import dataclass @@ -114,9 +116,19 @@ class MnnvlMemory: location.id = dev_id allocation_prop = cuda.CUmemAllocationProp() allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED - allocation_prop.requestedHandleTypes = ( - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ) + + # TODO: We differentiate FABRIC for GB200 (aarch64) and POSIX_FILE_DESCRIPTOR for BB200 (x86_64). + # May need to find a better way to handle this. + arch = platform.machine().lower() + is_on_aarch64 = "aarch64" in arch + if is_on_aarch64: + allocation_prop.requestedHandleTypes = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ) + else: + allocation_prop.requestedHandleTypes = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ) allocation_prop.location = location return allocation_prop @@ -178,10 +190,48 @@ class MnnvlMemory: ) exported_fabric_handle = _check_cu_result( cuda.cuMemExportToShareableHandle( - allocated_mem_handle, cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, 0 + allocated_mem_handle, allocation_prop.requestedHandleTypes, 0 ) ) - all_handles_data = comm.allgather(exported_fabric_handle.data) + if ( + allocation_prop.requestedHandleTypes + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): + all_handles_data = comm.allgather(exported_fabric_handle.data) + else: + all_handles_data = comm.allgather(exported_fabric_handle) + all_pids = comm.allgather(os.getpid()) + libc = ctypes.CDLL(None, use_errno=True) + syscall = libc.syscall + SYS_pidfd_open = 434 + SYS_pidfd_getfd = 438 + pidfds = [] + for i, pid in enumerate(all_pids): + pidfd = syscall(SYS_pidfd_open, pid, 0) + if pidfd < 0: + err = ctypes.get_errno() + raise RuntimeError( + f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}" + ) + pidfds.append(pidfd) + + remote_fds = [] + for i, (pidfd, fd) in enumerate(zip(pidfds, all_handles_data)): + remote_fd = syscall(SYS_pidfd_getfd, pidfd, fd, 0) + if remote_fd < 0: + err = ctypes.get_errno() + error_msg = f"pidfd_getfd(pidfd={pidfd}, fd={fd}) failed with errno {err}: {os.strerror(err)}." + if err == 1: # EPERM + error_msg += ( + " Permission denied. If running in a container, try adding --cap-add=SYS_PTRACE " + "to your docker run command." + ) + else: + error_msg += " This may be due to kernel version (requires Linux 5.6+)." + raise RuntimeError(error_msg) + remote_fds.append(remote_fd) + + all_handles_data = remote_fds # all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # noqa: E501 # can use buf = memoryview(data) to import if using plain buffer for data. @@ -205,7 +255,7 @@ class MnnvlMemory: # Fabric memory mapping imported_mem_handle = _check_cu_result( cuda.cuMemImportFromShareableHandle( - remote_handle_data, cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + remote_handle_data, allocation_prop.requestedHandleTypes ) ) mem_handles[i] = imported_mem_handle @@ -279,13 +329,11 @@ class MnnvlMemory: @staticmethod def supports_mnnvl() -> bool: # TODO: - # We check if it is an aarch64 platform and has all NVLink up now. + # We check if it has all NVLink up now. # But it is not equivalent to MNNVL support. # May need better support check. - arch = platform.machine().lower() - is_on_aarch64 = "aarch64" in arch support_nvlink_and_all_up = MnnvlMemory.support_nvlink(True) - return is_on_aarch64 and support_nvlink_and_all_up + return support_nvlink_and_all_up @dataclass diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 4e860f6abb..8bd4c49bbc 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -342,6 +342,7 @@ class PositionalEmbedder(Protocol): class RopeParams: dim: int = 0 theta: float = 10000.0 + alpha: float = 1.0 scale_type: RotaryScalingType = RotaryScalingType.none scale: float = 1.0 low_freq_factor: float = 1.0 @@ -357,6 +358,7 @@ class RopeParams: short_factor: Optional[Tuple[float]] = None long_factor: Optional[Tuple[float]] = None max_seq_len: Optional[int] = None + duplicate_data: bool = True @staticmethod def from_config(config) -> "RopeParams": @@ -383,6 +385,7 @@ class RopeParams: rope_params.scale_type = RotaryScalingType.none rope_params.scale = 1.0 if rope_scaling is not None: + rope_params.alpha = rope_scaling.get("alpha", 1.0) rotary_scaling_type = rope_scaling.get( "type", None) or rope_scaling.get("rope_type") rope_params.scale_type = RotaryScalingType.from_string( @@ -440,6 +443,7 @@ class RopeParams: self.beta_slow, self.mscale, self.mscale_all_dim, + self.duplicate_data, ) elif self.scale_type == RotaryScalingType.longrope: rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope( @@ -460,6 +464,7 @@ class RopeParams: self.scale_type, rope_scaling_config={ "factor": self.scale, + "alpha": self.alpha, "low_freq_factor": self.low_freq_factor, "high_freq_factor": self.high_freq_factor, "original_max_position_embeddings": diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index a833515020..b7a8b5a6fc 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union import torch from tensorrt_llm._utils import get_sm_version +from tensorrt_llm.bindings.internal import thop from tensorrt_llm.functional import AttentionMaskType from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig @@ -182,6 +183,7 @@ class TrtllmAttentionWrapper: spec_decoding_position_offsets: Optional[torch.Tensor] = None, spec_decoding_packed_mask: Optional[torch.Tensor] = None, spec_decoding_generation_lengths: Optional[torch.Tensor] = None, + attention_sinks: Optional[torch.Tensor] = None, **kwargs, ): """ @@ -217,6 +219,7 @@ class TrtllmAttentionWrapper: mla_context_paged_kv (torch.Tensor): The paged KV cache for MLA context, for kv cache reuse/chunked context. mla_context_kv_cache_block_offsets (torch.Tensor): The block offsets for the paged KV cache for MLA context, for kv cache reuse/chunked context. softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum) + attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU. """ self.layer_idx = layer_idx self.tokens_per_block = tokens_per_block @@ -253,6 +256,7 @@ class TrtllmAttentionWrapper: self.mla_context_paged_kv = mla_context_paged_kv self.mla_context_kv_cache_block_offsets = mla_context_kv_cache_block_offsets self.softmax_stats_tensor = softmax_stats_tensor + self.attention_sinks = attention_sinks if max_sequence_length > self.rope_params.max_positions: self.rope_params.max_positions = max_sequence_length @@ -416,7 +420,7 @@ class TrtllmAttentionWrapper: self.spec_decoding_position_offsets, self.spec_decoding_packed_mask ] - torch.ops.trtllm.attention_inplace( + thop.attention( q, k, v, @@ -442,6 +446,7 @@ class TrtllmAttentionWrapper: self.latent_cache, self.q_pe, self.block_ids_per_seq, + self.attention_sinks, is_fused_qkv, update_kv_cache, self.predicted_tokens_per_seq, @@ -1102,6 +1107,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): enable_attn_nvfp4_output: bool = True, output: Optional[torch.Tensor] = None, output_sf: Optional[torch.Tensor] = None, + attention_sinks: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: assert isinstance( @@ -1175,6 +1181,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): spec_decoding_packed_mask=metadata.spec_decoding_packed_mask, spec_decoding_generation_lengths=metadata. spec_decoding_generation_lengths, + attention_sinks=attention_sinks, ) out_dtype = None if out_scale is not None: diff --git a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py index 0b309ae2bf..c2081e00df 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py @@ -162,7 +162,7 @@ class CapturedGraph(nn.Module): # copy inputs to input buffers for i, input_tensor in enumerate(args_batched): - self._input_buffers[i][: input_tensor.shape[0]] = input_tensor + self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True) # run forward pass via graph self.graphs[combined_shape].replay() diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index af6f130cef..f7ad7934a9 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -19,10 +19,6 @@ transforms: stage: post_export cleanup_input_constraints: stage: post_export - quantize: - stage: pattern_matcher - quantize_moe: - stage: pattern_matcher match_repeat_kv: stage: pattern_matcher match_eager_attention: @@ -31,3 +27,35 @@ transforms: stage: pattern_matcher match_attention_layout: stage: pattern_matcher + match_moe_pattern: + stage: pattern_matcher + match_rope_pattern: + stage: pattern_matcher + match_rope_layout: + stage: pattern_matcher + eliminate_redundant_transposes: + stage: pattern_matcher + # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved + # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528 + optimize_rope: + stage: pattern_matcher + quantize_from_config: + stage: pattern_matcher + quantize_from_graph: + stage: pattern_matcher + quantize_moe: + stage: pattern_matcher + # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config. + detect_column_row_shard: + stage: sharding + simple_shard_only: false + detect_ep_shard: + stage: sharding + detect_dp_bmm_shard: + stage: sharding + # TODO: (hg) need to ensure run_shape_prop after sharding. + sharding_transform_executor: + stage: sharding + run_shape_prop: true + load_weights: + stage: weight_load diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 13c91652bf..d486d93b83 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -18,6 +18,8 @@ from torch._ops import OpOverloadPacket from torch.export import Dim from torch.fx import Node +from tensorrt_llm._utils import nvtx_range + @dataclass class CacheConfig: @@ -87,11 +89,13 @@ class SequenceInfo: # Similarly, if a batch is composed of generate-only requests, # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens). max_num_tokens: Optional[int] = None + # device is the device on which the sequence info is stored. + device: str = "cuda" ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP ################# # input_ids MUST ALWAYS BE THE FIRST FIELD - input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int)) - position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long)) + input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int)) + position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long)) seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int)) input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int)) @@ -110,24 +114,44 @@ class SequenceInfo: # NOTE (lucaslie): WAR to address issue when using flashinfer attention with # (max_batch_size, max_seq_len) input in trtllm runtime. # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 - max_seq_len_adjusted = self.max_seq_len + 1 + self.max_seq_len_adjusted = self.max_seq_len + 1 if self.max_num_tokens is None or self.max_num_tokens < 1: - self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted + self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted # if the provided max_num_tokens is less than the max_batch_size * max_seq_len, # we use the provided max_num_tokens to calculate the number of pages - total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted) + total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted) # Num pages can not be less than max_batch_size. self._num_pages = max( self.max_batch_size, (total_tokens) // self.page_size + (total_tokens % self.page_size > 0), ) - self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int) - self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long) - self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int) - self.input_pos = torch.empty_like(self.seq_len) - self.cache_loc = torch.empty(self.num_pages, dtype=torch.int) - self.pages_per_seq = torch.empty_like(self.seq_len) + # Ensure that the device is set before initializing the tensors. + self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) + self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device) + + # Consumers of the sequence info args require input_ids and position_ids to be truncated. + # We maintain a full version of the input_ids and position_ids to avoid overheads of tensor + # creation in every forward pass. + self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) + self.position_ids_full = torch.zeros( + self.max_num_tokens, dtype=torch.long, device=self.device + ) + + self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device) + self.input_pos = torch.empty_like(self.seq_len, device=self.device) + + # Allocated host tensors for sequence lengths and input positions so that + # position_ids calculation can be done on host. + self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int) + self.input_pos_host = torch.empty_like(self.seq_len_host) + + self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device) + self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device) + + self.previous_batch_indices_cuda = torch.empty( + self.max_num_tokens, dtype=torch.long, device=self.device + ) assert self.num_pages >= self.max_batch_size, ( "num_pages must be greater than max_batch_size" ) @@ -140,13 +164,12 @@ class SequenceInfo: # indicator if extra args are activated that are needed for cached attention backends self._is_cached_attn = False + # total number of tokens in the current batch + self.num_tokens: int = 0 + # call reset once to initialize the tensors self.reset() - @property - def device(self) -> torch.device: - return self.input_pos.device - @property def args(self) -> Tuple[torch.Tensor, ...]: args = [] @@ -156,11 +179,14 @@ class SequenceInfo: args.append(val) if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn: break + return tuple(args) @property def _num_uncached_attn_args(self) -> int: - """Return the number of original graph arguments expected by the model.""" + """Return the number of original graph arguments expected by the model. + This is 2 because we have input_ids and position_ids as the original graph arguments. + """ return 2 @property @@ -185,7 +211,7 @@ class SequenceInfo: dynamic_shapes = ({}, {}) if self.max_batch_size > 1: dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size) - dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len) + dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted) # set up shape for position_ids (same as input_ids) dynamic_shapes[1].update(dynamic_shapes[0]) # set up shape for extra args @@ -204,7 +230,7 @@ class SequenceInfo: @property def input_positions(self) -> List[int]: - return self.input_pos[: self.num_sequences].tolist() + return self.input_pos_host[: self.num_sequences].tolist() @property def is_generate(self) -> bool: @@ -334,14 +360,19 @@ class SequenceInfo: """ # reset input_pos self.input_pos.zero_() + self.input_pos_host.zero_() # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids) - self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int)) + self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True) # reset cache information self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device) self.pages_per_seq.fill_(1) + # let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens) + self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) + self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device) + def set_example_sequence(self) -> None: """Set an example sequence useful for testing and export purposes.""" self.reset() @@ -352,7 +383,7 @@ class SequenceInfo: dtype=torch.int, device=self.device, ) - self.nest_sequences(input_ids) + self.nest_sequences(input_ids, allow_realloc=True) # unflatten if we are not yet using cached+flattened attention if not self._is_cached_attn: @@ -370,7 +401,7 @@ class SequenceInfo: device=self.device, ) self.pages_per_seq.fill_(seq_len // self.page_size) - self.nest_sequences(input_ids) + self.nest_sequences(input_ids, allow_realloc=True) def set_generate_only_batch(self) -> None: """Set an example sequence for generate-only batch. @@ -379,32 +410,96 @@ class SequenceInfo: mode. So we don't need to do anything mode-specific here. """ self.reset() - self.nest_sequences([[1]] * self.max_batch_size) - - def _update_position_ids(self) -> None: - # set new position_ids as new tensor from input_pos and seq_len via torch.arange - position_ids_list = [ - num - for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths) - for num in range(in_pos, in_pos + seq_len) - ] - self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device) + self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True) + def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor: # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] if self.is_generate: - self.position_ids = self.position_ids.view(-1, 1) + return tensor.view(-1, 1, *tensor.shape[1:]) else: - self.position_ids = self.position_ids.view(1, -1) + return tensor.view(1, -1, *tensor.shape[1:]) - def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None: + @nvtx_range("ad_update_position_ids") + def _update_position_ids(self, allow_realloc: bool = False) -> None: + # set new position_ids from input_pos and seq_len + # Make sure this is done on host to avoid host-device copies. + with nvtx_range("prepare_list"): + # Optimize for the common case where all seq_len values are 1 (generation mode) + if torch.all(self.seq_len_host == 1): + # Fast path: when all seq_len are 1, position_ids is just input_pos_host + position_ids_host = ( + self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory() + ) + else: + # General case - can probably be optimized too, but overall impact will be minor. + position_ids_list = [] + for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host): + position_ids_list.extend(range(in_pos, in_pos + seq_len)) + position_ids_host = torch.tensor( + position_ids_list, dtype=torch.long, pin_memory=True + ) + with nvtx_range("copy_to_device"): + if allow_realloc: + # Create a new position_ids tensor on the device + self.position_ids = position_ids_host.to(self.device).clone() + else: + self.position_ids_full = self.position_ids_full.flatten() + self.position_ids_full[: self.num_tokens].copy_( + position_ids_host, non_blocking=True + ) + with nvtx_range("maybe_reshape"): + self.position_ids = self.maybe_reshape_for_generate( + self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens] + ) + + @nvtx_range("ad_update_sequence_lengths") + def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None: + self._sequence_lengths = sequence_lengths + self.num_tokens = sum(self._sequence_lengths) + self.seq_len.zero_() + self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True) + self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True) + + def update_input_ids_with_new_tokens( + self, new_tokens: torch.Tensor, previous_batch_indices: List[int] + ) -> None: + """Update the input_ids with new tokens. + + This function will update the input_ids with new tokens and previous batch indices. + """ + # 1) flatten once + original_shape = self.input_ids.shape + flat = self.input_ids.flatten() + + # copy indices to the GPU + host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True) + idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)] + idx.copy_(host_idx, non_blocking=True) + + # sort them so that masked_scatter_ lines up correctly + idx, _ = idx.sort() + + # gather the exact values you want to write + src = new_tokens[0, idx, 0] + + # in‐place fill every slot where flat == -1 with src, in order + flat.masked_scatter_(flat == -1, src) + + # 4) reshape back + self.input_ids = flat.view(original_shape) + + @nvtx_range("ad_nest_sequences") + def nest_sequences( + self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False + ) -> None: """Create and store a flattened list of input_ids from the provided list of sequences. + When allow_realloc is True, the input_ids will be reallocated on the device. This i/f will also update any relevant sequence information. """ # set new sequence lengths - seq_lens = [len(ids) for ids in input_ids] - self.seq_len.zero_() - self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True) + self._update_sequence_lengths([len(ids) for ids in input_ids]) + # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int # set new input_ids as new tensor from flattened input_ids @@ -413,49 +508,57 @@ class SequenceInfo: for lst in input_ids for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) ] - self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device) + input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True) - # set derivative properties - self._sequence_lengths = seq_lens - - # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] - if self.is_generate: - self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:]) + if allow_realloc: + self.input_ids = input_ids_host.to(self.device).clone() else: - self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:]) + self.input_ids_full = self.input_ids_full.flatten() + self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True) + self.input_ids = self.maybe_reshape_for_generate( + self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens] + ) # update position_ids - self._update_position_ids() + self._update_position_ids(allow_realloc=allow_realloc) + @nvtx_range("ad_unnest_sequences") def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0) return list(torch.split(t_squeezed, self.sequence_lengths)) + @nvtx_range("ad_update_pos") def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None: """Update the starting position for each sequence in the cache. If ``reset=True`, ``input_pos`` will be reset to zero before updating. """ if not isinstance(seq_len, torch.Tensor): - seq_len = torch.tensor(seq_len, dtype=torch.int) + seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True) bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size if reset: - self.input_pos[:bs] = seq_len.to(self.device) + self.input_pos_host[:bs].copy_(seq_len, non_blocking=True) else: - self.input_pos[:bs] += seq_len.to(self.device) + self.input_pos_host[:bs] += seq_len # update position_ids self._update_position_ids() + self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True) + @nvtx_range("ad_assign_cache_loc") def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None: """Set the cache location and pages_per_seq tensors from page assignments.""" cache_loc_flat = torch.tensor( - [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int + [p_idx for pages in page_assignments for p_idx in pages], + dtype=torch.int, + pin_memory=True, ) self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True) - pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int) + pages_per_seq = torch.tensor( + [len(p) for p in page_assignments], dtype=torch.int, pin_memory=True + ) self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True) diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 61337ae3f4..812dfea29c 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -3,9 +3,11 @@ from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Type, Union import torch -from pydantic import Field, ValidationInfo, field_validator, model_validator +from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict +from tensorrt_llm.models.modeling_utils import QuantConfig + from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, _ParallelConfig from ...llmapi.utils import get_type_repr from .models import ModelFactory, ModelFactoryRegistry @@ -259,6 +261,18 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings): ) garbage_collection_gen0_threshold: int = Field(default=20000, description="See TorchLlmArgs.") + _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) + + @property + def quant_config(self) -> QuantConfig: + if self._quant_config is None: + self._quant_config = QuantConfig() + return self._quant_config + + @quant_config.setter + def quant_config(self, value: QuantConfig): + self._quant_config = value + ### VALIDATION ################################################################################# @field_validator("build_config", mode="before") @classmethod diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index fc37c1e557..ec1da12bc9 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -1,6 +1,5 @@ """Interface to initialize and load HF models.""" -import json import os import types from contextlib import contextmanager, nullcontext @@ -31,6 +30,7 @@ from ..custom_ops.attention_interface import CacheConfig from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger from .factory import ModelFactory, ModelFactoryRegistry +from .quant_config_reader import QuantConfigReader, QuantConfigReaderRegistry @contextmanager @@ -73,26 +73,16 @@ class AutoModelForCausalLMFactory(ModelFactory): _model_defaults = { "use_cache": False, - "max_position_embeddings": 1024, } - def _get_max_position_embeddings_config(self) -> Dict[str, Any]: - """Get the max position embeddings config for the model.""" - return { - "max_position_embeddings": self.max_seq_len, - } - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - self._quant_config: Optional[Dict] = None - + self._quant_config_reader: QuantConfigReader | None = None # Ingest defaults for tokenizer and model kwargs self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs) self.model_kwargs = deep_merge_dicts( self._model_defaults, self.model_kwargs, - self._get_max_position_embeddings_config(), ) # special handling for torch_dtype in model_kwargs since HF does not correctly update @@ -156,9 +146,6 @@ class AutoModelForCausalLMFactory(ModelFactory): def _build_model(self, device: DeviceLikeType) -> nn.Module: """Build the model on the desired device.""" - # We only support fp16 to fp4 conversion. - if self._quant_config and self._quant_config.get("quant_algo", None) == "NVFP4": - self.model_kwargs["torch_dtype"] = torch.half # NOTE (lucaslie): HF doesn't recursively update nested PreTrainedConfig objects. Instead, # the entire subconfig will be overwritten. @@ -178,23 +165,27 @@ class AutoModelForCausalLMFactory(ModelFactory): model.forward = types.MethodType(self._simple_forward, model) model.eval() + return model def get_quant_config(self) -> Dict: - return self._quant_config or {} + """Returns the quantization config for this model or an empty dict if not quantized.""" + if self._quant_config_reader is not None: + return self._quant_config_reader.get_config() + return {} def get_cache_config(self): - """Setup cache information based on quantization information.""" - if self._quant_config is not None and "kv_cache_quant_algo" in self._quant_config.keys(): - kv_cache_format = self._quant_config.get("kv_cache_quant_algo", None) - if kv_cache_format is not None: - assert kv_cache_format == "FP8", ( - f"KV cache quantization format {kv_cache_format} is not supported." - ) - kv_cache_dtype = torch.float8_e4m3fn if kv_cache_format is not None else None - else: - kv_cache_dtype = None - return CacheConfig(dtype=kv_cache_dtype) + """Return kv cache dtype configuration.""" + if not self._quant_config_reader: + return CacheConfig(dtype=None) + + kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype") + torch_dtype = torch.float8_e4m3fn if kv_cache_dtype == "float8_e4m3fn" else None + assert torch_dtype in (torch.float8_e4m3fn, None), ( + f"Unsupported dtype: {torch_dtype}. Only torch.float8_e4m3fn is supported." + ) + + return CacheConfig(dtype=torch_dtype) def init_tokenizer(self) -> Optional[Any]: """Initialize the tokenizer—either a custom name or the model's default.""" @@ -325,44 +316,29 @@ class AutoModelForCausalLMFactory(ModelFactory): def _load_quantization_config(self, fetched_dir: str): """Load the quantization config from the model directory if not done already.""" - if self._quant_config is not None: + if self._quant_config_reader is not None: return + # TODO: specified by user or auto-detect + reader_cls = QuantConfigReaderRegistry.get("modelopt") + result = reader_cls.from_file(fetched_dir) + if result is None: + return + reader, extra_model_kwargs = result - assert self.model - hf_quant_config_file = os.path.join(fetched_dir, "hf_quant_config.json") - if os.path.exists(hf_quant_config_file): - with open(hf_quant_config_file, "r") as file: - quantization_config = json.load(file) - assert quantization_config.get("producer", {}).get("name", None) == "modelopt", ( - "Only support modelopt quantized checkpoint" - ) - self._quant_config = quantization_config.get("quantization", {}) - - # We do not quantize lm_head. - if "exclude_modules" not in self._quant_config: - self._quant_config["exclude_modules"] = ["lm_head"] + if reader is not None: + self._quant_config_reader = reader + self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs) @ModelFactoryRegistry.register("AutoModelForImageTextToText") class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory): _model_defaults = { "use_cache": False, - "max_position_embeddings": 1024, "text_config": { - "max_position_embeddings": 1024, "use_cache": False, }, } - def _get_max_position_embeddings_config(self) -> Dict[str, Any]: - """Get the max position embeddings config for the model.""" - return { - "max_position_embeddings": self.max_seq_len, - "text_config": { - "max_position_embeddings": self.max_seq_len, - }, - } - @property def automodel_from_config(self): return AutoModelForImageTextToText.from_config diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py b/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py new file mode 100644 index 0000000000..3aecdc5ecc --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py @@ -0,0 +1,99 @@ +import types +from typing import Callable, Dict, Optional + +import torch +from transformers.models.auto.modeling_auto import AutoModelForCausalLM + + +def gpt_oss_attention( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +): + """GPT OSS Attention forward function rewritten to wrap attention as a custom op.""" + from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb + + # Add new parameters + sliding_window = getattr(self, "sliding_window", -1) # Default to -1 if not present + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + # Apply Q, K, V projections (same as original) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + # Use original rope implementation + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # Handle KV cache properly + if past_key_value is not None: + # Update KV cache - check if it has update method (modern cache objects) + if hasattr(past_key_value, "update"): + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + else: + # Handle legacy tuple-based cache + if isinstance(past_key_value, tuple) and len(past_key_value) == 2: + past_key, past_value = past_key_value + key_states = torch.cat([past_key, key_states], dim=2) + value_states = torch.cat([past_value, value_states], dim=2) + + # Convert from [batch, num_heads, seq_len, head_dim] to [batch, seq_len, num_heads, head_dim] + query_states = query_states.transpose(1, 2).contiguous() + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + + # Get sinks parameter from model if available + sinks = None + if hasattr(self, "sinks"): + # If sinks is a model parameter, use it directly + sinks = self.sinks + + # Use custom op to capture attention. This layout is bsnd (batch, seq, num_heads, head_dim) + attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=True, + scale=self.scaling, + sinks=sinks, + sliding_window=sliding_window, + ) + + # Reshape back to original input shape + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, past_key_value + + +_from_config_original = AutoModelForCausalLM.from_config + +CUSTOM_MODULE_PATCHES: Dict[str, Callable] = { + "GptOssAttention": gpt_oss_attention, +} + + +def get_model_from_config_patched(config, **kwargs): + model = _from_config_original(config, **kwargs) + # Patch modules + for _, module in model.named_modules(): + if type(module).__name__ in CUSTOM_MODULE_PATCHES.keys(): + # Replace the forward method + module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module) + + return model + + +AutoModelForCausalLM.from_config = get_model_from_config_patched diff --git a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py new file mode 100644 index 0000000000..01196bf578 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py @@ -0,0 +1,130 @@ +""" +Quantization Config Reader Registry. + +This module defines a registry system for parsing quantization configurations +from various sources (e.g., 'modelopt'). It enables extensible support for different +quantization producers by delegating parsing logic to dedicated subclasses. +""" + +import json +import os +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional, Tuple, Type + + +class QuantConfigReader(ABC): + """Base class for reading and parsing quantization config.""" + + def __init__(self): + self._quant_config: Optional[Dict] = {} + + def get_config(self) -> Dict: + """Return the parsed quantization config.""" + return self._quant_config + + @abstractmethod + def read_config(self, config: Dict) -> Dict: + """ + Parse and normalize a quantization config dictionary. + + Args: + config: The raw parsed JSON object. + + Returns: + A dictionary of extra model kwargs derived from the quantization config. + Implementations must also populate self._quant_config with the normalized + quantization config. + """ + pass + + @classmethod + @abstractmethod + def from_file(cls, file_path: str) -> Optional[Tuple["QuantConfigReader", Dict[str, Any]]]: + """ + Load and parse a quantization config file from disk. + + This method is implemented by each reader to handle loading and parsing logic. + + Args: + file_path: Path to the quant config JSON file. + + Returns: + A (reader, extra_model_kwargs) tuple, or None if the file doesn't exist. + """ + pass + + +class QuantConfigReaderRegistry: + _registry: Dict[str, Type[QuantConfigReader]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[QuantConfigReader]], Type[QuantConfigReader]]: + def inner(reader_cls: Type[QuantConfigReader]) -> Type[QuantConfigReader]: + cls._registry[name] = reader_cls + return reader_cls + + return inner + + @classmethod + def get(cls, name: str) -> Type[QuantConfigReader]: + if name not in cls._registry: + raise ValueError(f"QuantConfigReader for '{name}' not registered.") + return cls._registry[name] + + @classmethod + def has(cls, reader_cls: str) -> bool: + return reader_cls in cls._registry + + +@QuantConfigReaderRegistry.register("modelopt") +class ModelOPTQuantConfigReader(QuantConfigReader): + def read_config(self, config: Dict) -> Dict: + producer = config.get("producer", {}).get("name") + # sanity check + if producer != "modelopt": + raise ValueError(f"Expected producer 'modelopt', got '{producer}'") + + quant_config = config.get("quantization", {}) + # Inject default exclusion, add "model.embed_tokens" for "tie_word_embedding:true" case + quant_config.setdefault("exclude_modules", ["lm_head", "model.embed_tokens"]) + # Update dtype + if quant_config.get("quant_algo") == "NVFP4": + quant_config["torch_dtype"] = "float16" + + # Handle kv cache + kv_algo = quant_config.get("kv_cache_quant_algo") + if kv_algo: + if kv_algo != "FP8": + raise ValueError(f"KV cache quantization format {kv_algo} not supported.") + quant_config["kv_cache_dtype"] = "float8_e4m3fn" + + self._quant_config = quant_config + + extra_model_kwargs: Dict[str, Any] = {} + if quant_config.get("quant_algo", None) == "NVFP4": + extra_model_kwargs["torch_dtype"] = "float16" + + return extra_model_kwargs + + @classmethod + def from_file( + cls, ckpt_dir: str + ) -> Optional[Tuple["ModelOPTQuantConfigReader", Dict[str, Any]]]: + """ + Load and parse a modelopt-style quantization config from a checkpoint directory. + + Args: + ckpt_dir: Path to the root directory containing the checkpoint. + + Returns: + An initialized ModelOPTQuantConfigReader instance, or None if the file doesn't exist. + """ + quant_file = os.path.join(ckpt_dir, "hf_quant_config.json") + if not os.path.exists(quant_file): + return None + + with open(quant_file, "r") as f: + raw = json.load(f) + reader = cls() + extra_model_kwargs = reader.read_config(raw) + return reader, extra_model_kwargs diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index ff0fb204f1..fea836bda4 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -94,20 +94,21 @@ class ADEngine(ModelEngine): f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}, {max_beam_width=}" ) + # update device to contain the current default device if it's in cuda + device = torch.device(ad_config.device) + if device.type == "cuda" and device.index is None: + device = torch.device(f"cuda:{torch.cuda.current_device()}") + device = str(device) + # initialize seq info object seq_info = SequenceInfo( max_seq_len=max_seq_len, max_batch_size=max_batch_size, page_size=attn_page_size, max_num_tokens=max_num_tokens, + device=device, ) - # update device to contain the current default device if it's in cuda - device = torch.device(ad_config.device) - if device.type == "cuda" and device.index is None: - device = torch.device(f"cuda:{torch.cuda.current_device()}") - device = str(device) - # construct inference optimizer build_and_optimize = InferenceOptimizer( factory=ad_config.create_factory(), ad_config=ad_config @@ -170,16 +171,12 @@ class ADEngine(ModelEngine): context_requests = scheduled_requests.context_requests gen_requests = [r for r in scheduled_requests.generation_requests if not r.draft_tokens] - # new_tokens is a tensor on the device, we need to convert it to a list of lists. - # can we avoid this additional gpu->cpu transfer? - new_tokens_list = new_tokens.flatten().cpu().tolist() if new_tokens is not None else None - # info to be extracted input_ids: List[List[int]] = [] input_pos: List[int] = [] last_logit_only: List[bool] = [] page_assignments: List[List[int]] = [] - + previous_batch_indices: List[int] = [] # look at context requests first for request in context_requests: # store input ids and pos of first token in sequence @@ -193,11 +190,13 @@ class ADEngine(ModelEngine): # TODO: we should also handle extend requests (for speculative decoding) here for request in gen_requests: # new_tokens are provided when the overlap scheduler is enabled. - if new_tokens_list is None or request.is_dummy or request.py_batch_idx is None: + if new_tokens is None or request.is_dummy or request.py_batch_idx is None: input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)]) input_pos.append(request.max_beam_num_tokens - 1) else: - input_ids.append([new_tokens_list[request.py_batch_idx]]) + # insert a dummy token to indicate the new tokens + input_ids.append([-1]) + previous_batch_indices.append(request.py_batch_idx) input_pos.append(request.max_beam_num_tokens) request.py_batch_idx = request.seq_slot @@ -213,11 +212,15 @@ class ADEngine(ModelEngine): # update the sequence info object now si = self.cache_seq_interface.info - si.nest_sequences(input_ids) si.update_pos(input_pos, reset=True) si.assign_cache_loc(page_assignments) + si.nest_sequences(input_ids) + + if new_tokens is not None: + si.update_input_ids_with_new_tokens(new_tokens, previous_batch_indices) return last_logit_only + @nvtx_range("ad_compute_logits") def _compute_logits(self) -> List[torch.Tensor]: # run the model logits: torch.Tensor = self.model(*self.cache_seq_interface.args)[0] @@ -234,13 +237,13 @@ class ADEngine(ModelEngine): self, scheduled_requests: ScheduledRequests, resource_manager: ResourceManager, - new_tokens_device: Optional[torch.Tensor] = None, + new_tensors_device: Optional[torch.Tensor] = None, gather_context_logits: bool = False, cache_indirection_buffer: Optional[torch.Tensor] = None, ): """Run forward from scheduled requests; main entrypoint that gets called by the executor.""" # convert requests and store in sequence info object - new_tokens = getattr(new_tokens_device, "new_tokens", None) + new_tokens = getattr(new_tensors_device, "new_tokens", None) last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens) # compute all logits diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index dd5bc421bb..1087714177 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -15,6 +15,7 @@ from ..models.factory import ModelFactory from ..shim.interface import CachedSequenceInterface from ..transformations._graph import canonicalize_graph, lift_to_meta from ..utils.logger import ad_logger +from ..utils.sharding_utils import ShardingConfig class TransformError(Exception): @@ -47,6 +48,14 @@ class Stages(Enum): return NotImplemented +class SharedConfig(BaseModel): + """Global config shared between multiple transforms in the inference optimizer.""" + + sharding_config: ShardingConfig = Field(default_factory=ShardingConfig) + local_rank: int = Field(default=0) + world_size: int = Field(default=1) + + class TransformConfig(BaseModel): """A simple configuration class that can be extended by a transform for configurability.""" @@ -190,7 +199,11 @@ class BaseTransform(ABC): @final def __call__( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> GraphModule: """Apply the transform to the graph. @@ -198,6 +211,7 @@ class BaseTransform(ABC): gm: The graph module to apply the transform to. cm: The cached sequence interface defining the sequence interface. factory: The model factory used to build the model. + shared_config: Global info shared between multiple transforms. Returns: GraphModule: The transformed graph module. @@ -232,14 +246,14 @@ class BaseTransform(ABC): # run the transform in a error-handling wrapper if desired if self.config.skip_on_error: try: - gm, info = self._apply(gm, cm, factory) + gm, info = self._apply(gm, cm, factory, shared_config) except Exception as e: error_msg = f"Transform {t_name} failed" ad_logger.warning(f"{error_msg}: {e}") info = TransformInfo(skipped=True, num_matches=0) else: # handle this here normally to improve debugging and error message - gm, info = self._apply(gm, cm, factory) + gm, info = self._apply(gm, cm, factory, shared_config) # we cannot say it's clean if the previous wasn't clean even if this one is # create new info object with updated cleanup status @@ -346,7 +360,11 @@ class BaseTransform(ABC): @abstractmethod def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: """Apply the transform to the graph. diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py index 94da4dd514..c0c2d88d70 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -13,7 +13,13 @@ from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger from ...utils.node_utils import is_op from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern -from ..interface import BaseTransform, TransformConfig, TransformInfo, TransformRegistry +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) def _apply_pattern( @@ -325,7 +331,11 @@ class MatchRepeatKV(BaseTransform): """ def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: def register_repeat_kv(patterns: ADPatternMatcherPass): dummy_args = [ @@ -366,7 +376,11 @@ class MatchEagerAttention(BaseTransform): """ def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: def register_eager_attention(patterns: ADPatternMatcherPass): for pattern_config in _get_sfdp_patterns(): @@ -392,7 +406,11 @@ class MatchGroupedAttention(BaseTransform): """ def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: def register_grouped_attention(patterns: ADPatternMatcherPass): q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16) @@ -478,7 +496,11 @@ class MatchAttentionLayout(BaseTransform): return MatchAttentionLayoutConfig def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: # Get attention layout from attention_op attention_layout = self.config.attention_op.get_attention_layout() diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py index 48a8accb20..8d99a27fcf 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py @@ -7,7 +7,13 @@ from torch.fx import GraphModule from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface -from ..interface import BaseTransform, TransformConfig, TransformInfo, TransformRegistry +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) class BuildModelConfig(TransformConfig): @@ -27,7 +33,11 @@ class BuildModel(BaseTransform): return BuildModelConfig def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: # build the model model = factory.build_model(self.config.device) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py b/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py index 1e5963505e..ec0e727d2f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py @@ -7,7 +7,7 @@ from torch.utils._sympy.value_ranges import ValueRanges from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface -from ..interface import BaseTransform, TransformInfo, TransformRegistry +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry # TODO (lucaslie): consider reconfiguring this transform to run before we switch to flattened @@ -22,7 +22,11 @@ class CleanupInputConstraints(BaseTransform): """ def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: graph: Graph = gm.graph input_node = graph.find_nodes(op="placeholder")[0] diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py b/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py index 4b2abf3106..b579c6ea7c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py @@ -6,7 +6,7 @@ from torch.fx import GraphModule from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import is_op -from ..interface import BaseTransform, TransformInfo, TransformRegistry +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @TransformRegistry.register("cleanup_noop_add") @@ -22,7 +22,11 @@ class CleanupNoopAdd(BaseTransform): """ def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: num_matches = 0 for node in gm.graph.nodes: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py b/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py index 4b58520931..4618ab582e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py @@ -6,7 +6,7 @@ from torch.fx import GraphModule from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import is_op -from ..interface import BaseTransform, TransformInfo, TransformRegistry +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @TransformRegistry.register("cleanup_noop_slice") @@ -19,7 +19,11 @@ class CleanupNoopSlice(BaseTransform): """ def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: num_matches = 0 for node in gm.graph.nodes: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py b/tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py new file mode 100644 index 0000000000..66b4c7e28c --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py @@ -0,0 +1,124 @@ +"""Graph transformation to eliminate redundant transpose operations in the model graph. + +This transformation identifies and removes patterns where transpose operations with the same +dimensions are applied consecutively, which cancel each other out: +x = x.transpose(1, 2) +x = x.transpose(1, 2) +""" + +from typing import Set, Tuple + +import torch +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.node_utils import is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +def _is_transpose_op(node: Node) -> bool: + """Check if the node is a transpose operation.""" + return is_op(node, torch.ops.aten.transpose) + + +def _is_contiguous_op(node: Node) -> bool: + """Check if the node is a contiguous operation.""" + return is_op(node, torch.ops.aten.contiguous) + + +def _are_transpose_args_same(node1: Node, node2: Node) -> bool: + """Check if two transpose nodes have the same dimension arguments.""" + # Get the dimension arguments for both nodes + # Args structure: (input_tensor, dim1, dim2) + if len(node1.args) < 3 or len(node2.args) < 3: + return False + + dim1_node1, dim2_node1 = node1.args[1], node1.args[2] + dim1_node2, dim2_node2 = node2.args[1], node2.args[2] + + # Check if the dimensions are the same + return dim1_node1 == dim1_node2 and dim2_node1 == dim2_node2 + + +@TransformRegistry.register("eliminate_redundant_transposes") +class EliminateRedundantTransposes(BaseTransform): + """Eliminate redundant transpose operations in the graph. + + This transformation identifies pairs of consecutive transpose operations with + the same dimension arguments and removes both operations, as they cancel out. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + + # Find pairs of redundant transpose operations + nodes_to_eliminate: Set[Tuple[Node, Node]] = set() + + for t_node in gm.graph.nodes: + # check if there is a transpose operation + if not _is_transpose_op(t_node): + continue + + # check if it's already part of a pair + if any(t_node in pair for pair in nodes_to_eliminate): + continue + + # check if there is only one user + if len(t_node.users) > 1: + continue + + # check if the user is a contiguous operation + t_comp_node = list(t_node.users)[0] + + # check if the user is a contiguous operation + has_contiguous = False + while _is_contiguous_op(t_comp_node) and len(t_comp_node.users) == 1: + has_contiguous = True + t_comp_node = list(t_comp_node.users)[0] + + # check if the user is a transpose operation + if not _is_transpose_op(t_comp_node): + continue + + # check if the transpose operation has the same dimension arguments + if not _are_transpose_args_same(t_node, t_comp_node): + continue + + # add the pair to the set + nodes_to_eliminate.add((t_node, t_comp_node, has_contiguous)) + + # Eliminate redundant transpose pairs + for t_node, t_comp_node, has_contiguous in nodes_to_eliminate: + # Replace all uses of the second transpose with the input to the first transpose + original_input = t_node.args[0] + t_comp_node.replace_all_uses_with(original_input) + + # if there is a contiguous operation that we skipped, let add it after t_comp_node as new + # graph node that call contiguous on t_comp_node + if has_contiguous: + with graph.inserting_after(original_input): + new_contiguous_node = graph.call_function( + torch.ops.aten.contiguous.default, args=(original_input,) + ) + original_input.replace_all_uses_with(new_contiguous_node) + new_contiguous_node.replace_input_with(new_contiguous_node, original_input) + + # Clean up the graph + if nodes_to_eliminate: + gm.graph.eliminate_dead_code() + + info = TransformInfo( + skipped=False, + num_matches=len(nodes_to_eliminate), + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index bbe72650b4..d07ab02c62 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -8,7 +8,13 @@ from torch.fx import GraphModule from ...export import torch_export_to_gm from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface -from ..interface import BaseTransform, TransformConfig, TransformInfo, TransformRegistry +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) class ExportToGMConfig(TransformConfig): @@ -44,7 +50,11 @@ class ExportToGM(BaseTransform): return ExportToGMConfig def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: # at this point we assume the gm is just a dummy graph module assert len(gm.graph.nodes) == 0, "Expected empty graph module." diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py new file mode 100644 index 0000000000..8a395ea912 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -0,0 +1,523 @@ +from collections import defaultdict +from typing import Optional, Tuple + +import torch +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.cuda_mem_tracker import cuda_memory_tracker +from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op +from ...utils.quantization_utils import get_scales_and_type_from_node +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +def _insert_fused_moe_ops(gm: GraphModule) -> int: + fused_key_counter = 0 + graph = gm.graph + + for node in list(graph.nodes): + if not is_op(node, torch.ops.auto_deploy.torch_moe): + continue + + hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = node.args + + fused_w3_w1_experts = torch.stack( + [ + torch.cat( + [gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2 + ) + for w1_node, w3_node in zip(w1_list, w3_list) + ], + dim=0, + ) + + fused_w2_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) + + new_key_w3_w1 = f"fused_moe_w3_w1_stacked_{fused_key_counter}" + new_key_w2 = f"fused_moe_w2_stacked_{fused_key_counter}" + fused_key_counter += 1 + param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts) + param_w2 = torch.nn.Parameter(fused_w2_experts) + gm.register_parameter(new_key_w3_w1, param_w3_w1) + gm.register_parameter(new_key_w2, param_w2) + + with graph.inserting_before(node): + new_node = graph.call_function( + # TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models + torch.ops.auto_deploy.trtllm_moe_fused, + args=( + hidden_states, + selected_experts, + routing_weights, + graph.get_attr(new_key_w3_w1), + graph.get_attr(new_key_w2), + ), + ) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + return fused_key_counter + + +def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]: + """ + Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following + each node's primary branch (recursively following the first Node argument). + + It first finds the LCA of the first two nodes and then + iteratively computes the LCA of the result with the next node, and so on. + + Returns: + The common ancestor Node if found, otherwise None. + """ + if not nodes: + return None + + def get_parent(node: Node) -> Optional[Node]: + """Return the first Node-valued argument for a given node, or None if not found.""" + for arg in node.args: + if isinstance(arg, Node): + return arg + return None + + def get_depth(node: Node) -> int: + """ + Recursively compute the depth of the node by following its primary branch. + Depth is defined as the number of steps to reach a node with no parent. + """ + parent = get_parent(node) + if parent is None: + return 0 + return 1 + get_depth(parent) + + def lca_two(a: Node, b: Node) -> Optional[Node]: + """ + Find the lowest common ancestor of two nodes by first equalizing their depth + and then moving upward until a common node is found. + """ + depth_a = get_depth(a) + depth_b = get_depth(b) + + # Equalize depths + while depth_a > depth_b: + a = get_parent(a) + depth_a -= 1 + while depth_b > depth_a: + b = get_parent(b) + depth_b -= 1 + + # Walk upward in lockstep + while a is not None and b is not None: + if a is b: + return a + a = get_parent(a) + b = get_parent(b) + return None + + # Iteratively compute the LCA across all nodes. + common = nodes[0] + for node in nodes[1:]: + common = lca_two(common, node) + if common is None: + return None + + return common + + +def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]: + """ + Given a linear op node, extract the input tensor node, weight tensor, + any quantization scales (if the op is quantized), and return a weight type. + + For a torch.ops.auto_deploy.torch_linear_simple.default op: + - Returns (input_node, weight, None, "simple") + + For a torch.ops.auto_deploy.torch_quant_fp8_linear op: + - Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8") + For a torch.ops.auto_deploy.torch_quant_fp4_linear op: + - Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4") + """ + input_node = linear_node.args[0] + if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple): + weight = linear_node.args[1] + return input_node, weight, None, "" + elif { + is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear) + or is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear), + }: + weight = linear_node.args[1] + scales, quant_type = get_scales_and_type_from_node(linear_node) + return input_node, weight, scales or {}, quant_type + + +def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node): + """ + Match the expert compute pattern between the given boundaries. + + The expert compute pattern corresponds to: + + (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t() + + For each expert, the function extracts the input node from the w1 branch and + collects the weight parameters from three linear ops (w1, w3, and w2 branches). + + This function supports both: + - torch.ops.auto_deploy.torch_linear_simple.default ops, and + - torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales). + - torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales). + + Returns: + A tuple: + (pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) + + - pattern_input_nodes: List of input nodes (x) used for the expert compute. + - pattern_output_nodes: List of final expert output nodes (the linear op with weight w2). + - expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors. + - expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors + (empty if weight_type is "simple"). + - weight_type: "fp8" if FP8 ops were used, "simple" otherwise. + """ + pattern_input_nodes, pattern_output_nodes = [], [] + expert_weights = defaultdict(list) + expert_scales = defaultdict(list) + weight_type = "simple" # default + + nodes = list(start_boundary.graph.nodes) + region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)] + + for node in region_nodes: + # Accept both simple and quantized linear ops. + if not is_linear_op(node, include_quantization=True): + continue + + final_linear = node + if not final_linear.args or not isinstance(final_linear.args[0], Node): + continue + + mul_node = final_linear.args[0] + if not is_op(mul_node, torch.ops.aten.mul) or len(mul_node.args) < 2: + continue + + arg_a, arg_b = mul_node.args[:2] + silu_node = ( + arg_a + if is_op(arg_a, torch.ops.aten.silu) + else arg_b + if is_op(arg_b, torch.ops.aten.silu) + else None + ) + if silu_node is None: + continue + + if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)): + continue + linear_w1_node = silu_node.args[0] + + # The other branch should be a linear op (w3 branch). + linear_w3_node = arg_b if arg_a is silu_node else arg_a + if not is_linear_op(linear_w3_node, include_quantization=True): + continue + if not (linear_w1_node.args and linear_w3_node.args): + continue + + # Extract parameters from each linear op. + input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters( + linear_w1_node + ) + _, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node) + _, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear) + + if None in (weight_w1, weight_w3, weight_w2): + continue + + # Ensure the weight type is consistent across branches. + if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2: + continue + weight_type = wt_type_w1 + + pattern_input_nodes.append(input_node_w1) + pattern_output_nodes.append(final_linear) + expert_weights["w1"].append(weight_w1) + expert_weights["w3"].append(weight_w3) + expert_weights["w2"].append(weight_w2) + + # TODO: sanity check that all experts have same weight type + if weight_type == "fp8": + expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"]) + expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"]) + expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"]) + expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"]) + expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"]) + expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"]) + elif weight_type == "fp4": + expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"]) + expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"]) + expert_scales["w1_alpha"].append(quant_params_w1["alpha"]) + expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"]) + expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"]) + expert_scales["w3_alpha"].append(quant_params_w3["alpha"]) + expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"]) + expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"]) + expert_scales["w2_alpha"].append(quant_params_w2["alpha"]) + + return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type + + +def _find_final_hidden_state_node( + pattern_output_nodes: list[Node], end_boundary: Node +) -> Optional[Node]: + """ + Identify the final hidden state node corresponding to the combine pattern: + + (expert_output * routing_weight) → index_add_ + + For each expert output node (from the expert compute pattern), this function: + 1. Retrieves a multiplication node from its users. + 2. Extracts the second argument from the multiplication node (assumed to be the index node). + 3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary). + + After collecting all such index_add_ nodes, the final hidden state node is determined + as the one that is not used by any of the other index_add_ nodes. + + If any required attribute (users or args) is missing during the process or if no valid + final node is found, the function returns None. + """ + + if not pattern_output_nodes: + return None + + index_add_nodes = [] + for node in pattern_output_nodes: + if not node.users: + return None + mul_node = next(iter(node.users)) + if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2): + return None + index_node = mul_node.args[1] + index_add_node = bfs( + index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary + ) + if not index_add_node: + return None + index_add_nodes.append(index_add_node) + + # The final node is defined as the index_add_node that is not used by any other index_add_nodes + return next( + ( + candidate + for candidate in index_add_nodes + if not any( + candidate in other.args for other in index_add_nodes if candidate is not other + ) + ), + None, + ) + + +def _extract_index_branches_from_expert_outputs( + pattern_output_nodes: list[Node], +) -> tuple[list[Node], list[Node]]: + """ + Extract routing and experts branches from expert outputs. + + For each expert output, find its multiplication user. From the + multiplication node's second argument (an index node), + extract: + - The first argument as the routing branch. + - The second argument (flattened if a list/tuple) as the experts branch. + + Returns: + A tuple (routing_branches, experts_branches). + """ + routing_branches, experts_branches = [], [] + for out in pattern_output_nodes: + mul = next((u for u in out.users if is_op(u, torch.ops.aten.mul)), None) + if not mul or len(mul.args) < 2: + continue + idx_node = mul.args[1] + if not is_op(idx_node, torch.ops.aten.index): + continue + routing_branches.append(idx_node.args[0]) + experts = idx_node.args[1] + experts_branches.extend(experts) if isinstance( + experts, (list, tuple) + ) else experts_branches.append(experts) + return routing_branches, experts_branches + + +def _remove_dead_inplace_nodes_in_region( + graph: torch.fx.Graph, + start_boundary: torch.fx.Node, + end_boundary: torch.fx.Node, +) -> bool: + """ + Searches (via BFS) for a dead in-place node (index_add_) in the region + between start_boundary and end_boundary. If one is found, it is removed from the graph. + Returns True if a node was removed, False otherwise. + """ + + def target(n: torch.fx.Node) -> bool: + return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0 + + try: + node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary) + graph.erase_node(node_to_remove) + return True + except RuntimeError: + return False + + +@TransformRegistry.register("match_moe_pattern") +class MatchMoePattern(BaseTransform): + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + + # Preprocessing: Identify boundary nodes (e.g. residual connections) in the graph. + boundary_nodes = identify_regions_between_residuals(gm) + + num_moe_patterns = 0 + + for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]): + # Step 1: Identify Expert Compute pattern + ( + pattern_input_nodes, + pattern_output_nodes, + expert_weights, + expert_scales, + weight_type, + ) = _match_expert_compute_pattern(start_boundary, end_boundary) + if not expert_weights: + continue + # TODO: naming convention to verify the order of the weight nodes + + # Step 2: Trace upwards to locate normalize_routing_weight and selected_experts: + arg1_list, arg2_list = _extract_index_branches_from_expert_outputs(pattern_output_nodes) + normalized_routing_weights = _find_lowest_common_ancessor(arg1_list) + if not normalized_routing_weights: + continue + + common_ancessor2 = _find_lowest_common_ancessor(arg2_list) + if not common_ancessor2: + continue + selected_experts = bfs( + common_ancessor2, + lambda node: is_op(node, torch.ops.aten.one_hot), + attr_next="all_input_nodes", + boundary=start_boundary, + ).args[0] + if not selected_experts: + continue + + # Step 3: Trace upwards to find input node: + hidden_states = _find_lowest_common_ancessor(pattern_input_nodes) + if not hidden_states: + continue + + # Step 4: Find output node with the combine pattern + final_hidden_state_node = _find_final_hidden_state_node( + pattern_output_nodes, end_boundary + ) + if final_hidden_state_node is None: + continue + + # Step 5: Insert the MoE op into the graph. + with graph.inserting_before(final_hidden_state_node): + w1_list = expert_weights["w1"] + w2_list = expert_weights["w2"] + w3_list = expert_weights["w3"] + + if weight_type == "fp8": + fused_moe_node = graph.call_function( + torch.ops.auto_deploy.torch_quant_fp8_moe, + args=( + hidden_states, + selected_experts, + normalized_routing_weights, + w1_list, + w2_list, + w3_list, + expert_scales["w1_input_scale"], + expert_scales["w2_input_scale"], + expert_scales["w3_input_scale"], + expert_scales["w1_weight_scale"], + expert_scales["w2_weight_scale"], + expert_scales["w3_weight_scale"], + ), + ) + elif weight_type == "fp4": + fused_moe_node = graph.call_function( + torch.ops.auto_deploy.torch_quant_fp4_moe, + args=( + hidden_states, + selected_experts, + normalized_routing_weights, + w1_list, + w2_list, + w3_list, + expert_scales["w1_input_scale"], + expert_scales["w2_input_scale"], + expert_scales["w3_input_scale"], + expert_scales["w1_weight_scale"], + expert_scales["w2_weight_scale"], + expert_scales["w3_weight_scale"], + expert_scales["w1_alpha"], + expert_scales["w2_alpha"], + expert_scales["w3_alpha"], + ), + ) + else: + fused_moe_node = graph.call_function( + torch.ops.auto_deploy.torch_moe, + args=( + hidden_states, + selected_experts, + normalized_routing_weights, + w1_list, + w2_list, + w3_list, + ), + ) + + final_hidden_state_node.replace_all_uses_with(fused_moe_node) + graph.erase_node(final_hidden_state_node) + + while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary): + gm.graph.eliminate_dead_code() + + num_moe_patterns += 1 + + info = TransformInfo( + skipped=False, num_matches=num_moe_patterns, is_clean=False, has_valid_shapes=False + ) + return gm, info + + +@TransformRegistry.register("fuse_moe") +class FuseMoe(BaseTransform): + """ + Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with + torch.ops.auto_deploy.trtllm_moe_fused. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + with cuda_memory_tracker(): + fused_key_counter = _insert_fused_moe_ops(gm) + + info = TransformInfo( + skipped=False, num_matches=fused_key_counter, is_clean=False, has_valid_shapes=False + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py new file mode 100644 index 0000000000..caaaea0934 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py @@ -0,0 +1,54 @@ +"""A simple wrapper transform to build a model via the model factory.""" + +from typing import Optional, Tuple, Type + +from pydantic import Field +from torch.fx import GraphModule + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...transformations._graph import move_to_device +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +class LoadWeightsToDeviceConfig(TransformConfig): + """Configuration for the load weights transform.""" + + device: str = Field(default="meta", description="The device to load the weights on.") + adconfig_checkpoint_device: Optional[str] = Field( + default=None, description="Optional checkpoint device argument from adconfig." + ) + + +@TransformRegistry.register("load_weights") +class LoadWeightsToDevice(BaseTransform): + """A simple wrapper transform to load weights into a model.""" + + config: LoadWeightsToDeviceConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return LoadWeightsToDeviceConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + factory.load_or_random_init( + gm, device=self.config.adconfig_checkpoint_device or self.config.device + ) + move_to_device(gm, self.config.device) + cm.to(self.config.device) + + info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 8cf3630b82..7f0a55b9ee 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -1,6 +1,5 @@ -from collections import defaultdict from functools import partial -from typing import Dict, Tuple +from typing import Tuple import torch.nn as nn from torch.fx import GraphModule, Node @@ -21,7 +20,7 @@ from ...utils.quantization_utils import ( remove_output_quantizers, should_skip_quantization, ) -from ..interface import BaseTransform, TransformInfo, TransformRegistry +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry def _insert_quantized_linear( @@ -166,67 +165,99 @@ def _insert_quantized_bmm( node.args = (*node.args, *scale_values) -@TransformRegistry.register("quantize") -class Quantization(BaseTransform): - """Quantize the GraphModule and replace linear/BMM with quantized linear/BMM.""" +@TransformRegistry.register("quantize_from_config") +class QuantizationFromConfig(BaseTransform): + """ + Quantize linear and BMM ops using a quantization config. + + Replaces eligible ops with quantized equivalents based on the quantization algorithm + and exclude patterns defined in the config. + """ def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - # extract info from quant_config quant_config = factory.get_quant_config() if not quant_config: return gm, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - - is_quant_graph = is_quantized_graph(gm) - quant_algo = quant_config.get("quant_algo") + quant_algo = quant_config.get("quant_algo", None) excluded_patterns = quant_config.get("exclude_modules", []) if not quant_algo: return gm, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - # tracking quantized operations in the graph - quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) + num_matches = 0 + for n in gm.graph.nodes: if should_skip_quantization(n, excluded_patterns): continue - # Process linear operations if is_linear_op(n, include_quantization=False): - # get per-layer quantization format from the node - quant_algo_n: str = ( - get_quantization_from_linear_node(n) if is_quant_graph else quant_algo - ) - if not quant_algo_n: - continue + impl = QuantizationImpl.create(quant_algo, is_bmm=False) + _insert_quantized_linear(gm, n, impl, False) + num_matches += 1 - # insert quantized linear node - _insert_quantized_linear( - gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph - ) - quantized_nodes[quant_algo_n]["linear"] += 1 - - # Process BMM operations + # TODO: Make _insert_quantized_bmm return a bool and increment only on success elif is_bmm_op(n): - if not quant_algo: - continue - - # insert quantized bmm node - _insert_quantized_bmm( - gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph - ) - quantized_nodes[quant_algo]["bmm"] += 1 - - if is_quant_graph: - remove_output_quantizers(gm) - - num_matches = 0 - for quant_algo in quantized_nodes: - for op_type, count in quantized_nodes[quant_algo].items(): - num_matches += count + impl = QuantizationImpl.create(quant_algo, is_bmm=True) + _insert_quantized_bmm(gm, n, impl, False) + num_matches += 1 + + info = TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True + ) + + return gm, info + + +@TransformRegistry.register("quantize_from_graph") +class QuantizationFromGraph(BaseTransform): + """ + Fuse ModelOpt-quantized linear ops into fused quantized implementations. + + Detects quantized nodes from ModelOpt checkpoints's graph and replaces them with + fused linear ops based on the quantization type. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + is_quant_graph = is_quantized_graph(gm) + + # no quantization to do + if not is_quant_graph: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + # tracking quantized operations in the graph + num_matches = 0 + for n in gm.graph.nodes: + # Process linear operations + if is_linear_op(n, include_quantization=False): + # get per-layer quantization format from the node + quant_algo_n: str = get_quantization_from_linear_node(n) + if not quant_algo_n: + continue + + # insert quantized linear node + _insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), True) + num_matches += 1 + + # To check: quant BMM does not have graph based pass? + + remove_output_quantizers(gm) info = TransformInfo( skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py index b7b24cd5d5..e930543aef 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py @@ -9,7 +9,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import is_op from ...utils.quantization_utils import QuantizationImpl, should_skip_quantization -from ..interface import BaseTransform, TransformInfo, TransformRegistry +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry quantized_moe_op_map = { "FP8": torch.ops.auto_deploy.torch_quant_fp8_moe, @@ -139,7 +139,11 @@ class QuantizeMOE(BaseTransform): """ def _apply( - self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: quant_config = factory.get_quant_config() quant_algo = quant_config.get("quant_algo") if quant_config else None diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py similarity index 58% rename from tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py rename to tensorrt_llm/_torch/auto_deploy/transform/library/rope.py index 65e7f7f614..9707dcfd06 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py @@ -47,15 +47,23 @@ TODO: Support other variants: import operator from collections import defaultdict -from typing import Any, DefaultDict, Dict, Optional, Sequence +from typing import Any, DefaultDict, Dict, Optional, Sequence, Tuple, Type import torch +from pydantic import Field from torch.fx import GraphModule, Node -from ...utils.logger import ad_logger +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import extract_op_args, extract_output_tuple, is_op from ...utils.pattern_matcher import ADPatternMatcherPass, Match, register_ad_pattern -from .._graph import canonicalize_graph +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) def _rotate_half(x): @@ -119,221 +127,270 @@ def _explicit_not_interleaved(match: Match) -> bool: return not any(isinstance(n, Node) and _match_input_interleave_pattern(n) for n in (q, k)) -def match_rope_pattern(gm: GraphModule) -> int: - graph = gm.graph - patterns = ADPatternMatcherPass() +@TransformRegistry.register("match_rope_pattern") +class MatchRopePattern(BaseTransform): + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + patterns = ADPatternMatcherPass() - # dummy shapes: can be arbitrary - batch_size = 8 - seq_len = 16 - num_heads = 8 - hidden_size = 512 - head_dim = hidden_size // num_heads + # dummy shapes: can be arbitrary + batch_size = 8 + seq_len = 16 + num_heads = 8 + hidden_size = 512 + head_dim = hidden_size // num_heads - dummy_explicit = [ - torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16), - torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16), - torch.randn(batch_size, seq_len, head_dim, device="meta", dtype=torch.float16), - torch.randn(batch_size, seq_len, head_dim, device="meta", dtype=torch.float16), - ] - dummy_complex = [ - torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16), - torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16), - torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float16), - ] - # float32 input can change the graph when there's .float() in pattern - dummy_complex_2 = [ - torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32), - torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32), - torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float32), - ] - register_ad_pattern( - search_fn=_explicit_rope_pattern, - replace_fn=_explicit_rope_repl, - patterns=patterns, - dummy_args=dummy_explicit, - op_ignore_types={torch.ops.aten.slice.Tensor: (int,)}, - scalar_workaround={"unsqueeze_dim": 1}, - extra_check=_explicit_not_interleaved, - ) - register_ad_pattern( - search_fn=_interleaved_rope_pattern, - replace_fn=_interleaved_rope_repl, - patterns=patterns, - dummy_args=dummy_explicit, - op_ignore_types={ - torch.ops.aten.slice.Tensor: (int,), - torch.ops.aten.reshape.default: (int,), - torch.ops.aten.view.default: (int,), - }, - scalar_workaround={"unsqueeze_dim": 1}, - ) - register_ad_pattern( - search_fn=_complex_rope_pattern, - replace_fn=_complex_rope_repl, - patterns=patterns, - dummy_args=dummy_complex, - op_ignore_types={ - torch.ops.aten.reshape.default: (int,), - }, - scalar_workaround={"unsqueeze_dim": 1}, - ) - register_ad_pattern( - search_fn=_complex_rope_pattern, - replace_fn=_complex_rope_repl, - patterns=patterns, - dummy_args=dummy_complex_2, - op_ignore_types={ - torch.ops.aten.reshape.default: (int,), - }, - scalar_workaround={"unsqueeze_dim": 1}, + dummy_explicit = [ + torch.randn( + batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16 + ), + torch.randn( + batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16 + ), + torch.randn(batch_size, seq_len, head_dim, device="meta", dtype=torch.float16), + torch.randn(batch_size, seq_len, head_dim, device="meta", dtype=torch.float16), + ] + dummy_complex = [ + torch.randn( + batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16 + ), + torch.randn( + batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16 + ), + torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float16), + ] + dummy_complex_2 = [ + torch.randn( + batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32 + ), + torch.randn( + batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32 + ), + torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float32), + ] + register_ad_pattern( + search_fn=_explicit_rope_pattern, + replace_fn=_explicit_rope_repl, + patterns=patterns, + dummy_args=dummy_explicit, + op_ignore_types={torch.ops.aten.slice.Tensor: (int,)}, + scalar_workaround={"unsqueeze_dim": 1}, + extra_check=_explicit_not_interleaved, + ) + register_ad_pattern( + search_fn=_interleaved_rope_pattern, + replace_fn=_interleaved_rope_repl, + patterns=patterns, + dummy_args=dummy_explicit, + op_ignore_types={ + torch.ops.aten.slice.Tensor: (int,), + torch.ops.aten.reshape.default: (int,), + torch.ops.aten.view.default: (int,), + }, + scalar_workaround={"unsqueeze_dim": 1}, + ) + register_ad_pattern( + search_fn=_complex_rope_pattern, + replace_fn=_complex_rope_repl, + patterns=patterns, + dummy_args=dummy_complex, + op_ignore_types={ + torch.ops.aten.reshape.default: (int,), + }, + scalar_workaround={"unsqueeze_dim": 1}, + ) + register_ad_pattern( + search_fn=_complex_rope_pattern, + replace_fn=_complex_rope_repl, + patterns=patterns, + dummy_args=dummy_complex_2, + op_ignore_types={ + torch.ops.aten.reshape.default: (int,), + }, + scalar_workaround={"unsqueeze_dim": 1}, + ) + + num_matches = patterns.apply(graph) + + info = TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False + ) + + return gm, info + + +class MatchRopeLayoutConfig(TransformConfig): + """Configuration for the match rope layout transform.""" + + expected_layout: str = Field( + default="bsnd", + description="The expected layout of the rope operation. Must be one of 'bsnd' or 'bnsd'.", ) - num_matches = patterns.apply(graph) - canonicalize_graph(gm) - ad_logger.info(f"Found and matched {num_matches} RoPE patterns") - return num_matches - -def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> None: +@TransformRegistry.register("match_rope_layout") +class MatchRopeLayout(BaseTransform): """ Match and transform input and output of rope ops to the layout specified to meet requirements of optimized ops. Supported layout is 'bsnd' (batch, seq, head, dim). """ - supported = {"bsnd", "bnsd"} - if expected_layout.lower() not in supported: - ad_logger.warning( - f"Unsupported RoPE layout '{expected_layout}'; expected '{supported}'. Skipping RoPE layout matching." + + config: MatchRopeLayoutConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return MatchRopeLayoutConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + supported = {"bsnd", "bnsd"} + if self.config.expected_layout.lower() not in supported: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + graph = gm.graph + rope_ops = { + torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin, + torch.ops.auto_deploy.torch_rope_with_qk_interleaving, + torch.ops.auto_deploy.torch_rope_with_complex_freqs, + } + + need_transpose = False + num_rope_layout_matches = 0 + for node in graph.nodes: + if not is_op(node, rope_ops): + continue + + if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): + q_node, k_node, freqs_node, unsq = extract_op_args( + node, + "xq", # argument name in schema + "xk", + "freqs_cis", + "unsqueeze_dim", + ) + else: + q_node, k_node, cos_node, sin_node, unsq = extract_op_args( + node, "q", "k", "cos", "sin", "unsqueeze_dim" + ) + + if unsq == 2: + current_layout = "bsnd" + elif unsq == 1: + current_layout = "bnsd" + else: + continue + + need_transpose = self.config.expected_layout.lower() != current_layout + + if not need_transpose: + continue + + num_rope_layout_matches += 1 + # retrieve q and k output node from node + q_rope_old, k_rope_old = extract_output_tuple(node, 2) + if q_rope_old is None or k_rope_old is None: + continue + + with graph.inserting_before(node): + q_for_op = graph.call_function(torch.ops.aten.transpose, args=(q_node, 1, 2)) + k_for_op = graph.call_function(torch.ops.aten.transpose, args=(k_node, 1, 2)) + q_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(q_for_op,)) + k_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(k_for_op,)) + + q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2) + k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2) + + if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): + new_args = ( + q_for_op_contig, + k_for_op_contig, + freqs_node, + 2 if self.config.expected_layout.lower() == "bsnd" else 1, + ) # unsqueeze_dim updated + else: + new_args = ( + q_for_op_contig, + k_for_op_contig, + cos_node, + sin_node, + 2 if self.config.expected_layout.lower() == "bsnd" else 1, + ) # unsqueeze_dim updated + node.args = new_args + + with graph.inserting_after(q_rope_old): + q_rope_new = graph.call_function(torch.ops.aten.transpose, args=(q_rope_old, 1, 2)) + with graph.inserting_after(k_rope_old): + k_rope_new = graph.call_function(torch.ops.aten.transpose, args=(k_rope_old, 1, 2)) + + # Preserve fake tensor in meta["val"] for the transposed inputs + q_rope_new.meta["val"] = q_rope_old.meta["val"] + q_rope_old.meta["val"] = q_rope_old.meta["val"].transpose(1, 2) + k_rope_new.meta["val"] = k_rope_old.meta["val"] + k_rope_old.meta["val"] = k_rope_old.meta["val"].transpose(1, 2) + + q_rope_old.replace_all_uses_with(q_rope_new) + k_rope_old.replace_all_uses_with(k_rope_new) + q_rope_new.args = (q_rope_old, 1, 2) + k_rope_new.args = (k_rope_old, 1, 2) + + info = TransformInfo( + skipped=False, + num_matches=num_rope_layout_matches, + is_clean=False, + has_valid_shapes=False, ) - return - ad_logger.info(f"Match RoPE layout to {expected_layout}") - - graph = gm.graph - rope_ops = { - torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin, - torch.ops.auto_deploy.torch_rope_with_qk_interleaving, - torch.ops.auto_deploy.torch_rope_with_complex_freqs, - } - - need_transpose = False - num_rope_layout_matches = 0 - for node in graph.nodes: - if not is_op(node, rope_ops): - continue - - if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): - q_node, k_node, freqs_node, unsq = extract_op_args( - node, - "xq", # argument name in schema - "xk", - "freqs_cis", - "unsqueeze_dim", - ) - else: - q_node, k_node, cos_node, sin_node, unsq = extract_op_args( - node, "q", "k", "cos", "sin", "unsqueeze_dim" - ) - - if unsq == 2: - current_layout = "bsnd" - elif unsq == 1: - current_layout = "bnsd" - else: - ad_logger.warning( - "Unsqueeze_dim is not one of [1, 2]. " - "Unable to infer layout of q node. Skip layout matching" - ) - continue - - need_transpose = expected_layout.lower() != current_layout - - if not need_transpose: - continue - - num_rope_layout_matches += 1 - # retrieve q and k output node from node - q_rope_old, k_rope_old = extract_output_tuple(node, 2) - if q_rope_old is None or k_rope_old is None: - ad_logger.warning( - f"Failed to extract all two outputs from the explicit op, \ - get {q_rope_old}, {k_rope_old}, fail to match rope layout with {node} with" - ) - continue - - ad_logger.debug( - f"Inferred RoPE input layout: '{current_layout}']Mapping layout to '{expected_layout}']" - ) - with graph.inserting_before(node): - q_for_op = graph.call_function(torch.ops.aten.transpose, args=(q_node, 1, 2)) - k_for_op = graph.call_function(torch.ops.aten.transpose, args=(k_node, 1, 2)) - q_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(q_for_op,)) - k_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(k_for_op,)) - - q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2) - k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2) - - if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): - new_args = ( - q_for_op_contig, - k_for_op_contig, - freqs_node, - 2 if expected_layout.lower() == "bsnd" else 1, - ) # unsqueeze_dim updated - else: - new_args = ( - q_for_op_contig, - k_for_op_contig, - cos_node, - sin_node, - 2 if expected_layout.lower() == "bsnd" else 1, - ) # unsqueeze_dim updated - node.args = new_args - - with graph.inserting_after(q_rope_old): - q_rope_new = graph.call_function(torch.ops.aten.transpose, args=(q_rope_old, 1, 2)) - with graph.inserting_after(k_rope_old): - k_rope_new = graph.call_function(torch.ops.aten.transpose, args=(k_rope_old, 1, 2)) - - # Preserve fake tensor in meta["val"] for the transposed inputs - q_rope_new.meta["val"] = q_rope_old.meta["val"] - q_rope_old.meta["val"] = q_rope_old.meta["val"].transpose(1, 2) - k_rope_new.meta["val"] = k_rope_old.meta["val"] - k_rope_old.meta["val"] = k_rope_old.meta["val"].transpose(1, 2) - - q_rope_old.replace_all_uses_with(q_rope_new) - k_rope_old.replace_all_uses_with(k_rope_new) - q_rope_new.args = (q_rope_old, 1, 2) - k_rope_new.args = (k_rope_old, 1, 2) - - if num_rope_layout_matches: - canonicalize_graph(gm) - ad_logger.info(f"Found {num_rope_layout_matches} RoPE layout matches") + return gm, info -def optimize_rope(gm: GraphModule) -> None: +@TransformRegistry.register("optimize_rope") +class OptimizeRope(BaseTransform): """ Scan the FX graph and replace calls to the torch-reference RoPE ops with the optimized `rope::flashinfer` kernel. Precomputes positional IDs and the fused cosine-sine cache as explicit nodes, and reuses those nodes when possible. """ - graph = gm.graph - rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None) - rope_position_ids_cache: Dict[str, Node] = {} - num_rope_optimizations = 0 - for node in list(graph.nodes): - if is_op(node, torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin): - _optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache) - elif is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): - _optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache) - else: - continue - num_rope_optimizations += 1 - if num_rope_optimizations: - canonicalize_graph(gm) - ad_logger.info(f"Found {num_rope_optimizations} RoPE optimizations") + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None) + rope_position_ids_cache: Dict[str, Node] = {} + + num_rope_optimizations = 0 + for node in list(graph.nodes): + if is_op(node, torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin): + _optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache) + elif is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): + _optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache) + else: + continue + num_rope_optimizations += 1 + + info = TransformInfo( + skipped=False, num_matches=num_rope_optimizations, is_clean=False, has_valid_shapes=True + ) + + return gm, info def _optimize_explicit( @@ -344,10 +401,6 @@ def _optimize_explicit( # retrieve q and k output node from node q_rope_old, k_rope_old = extract_output_tuple(node, 2) if q_rope_old is None or k_rope_old is None: - ad_logger.warning( - f"Failed to extract all two outputs from the explicit op, \ - get {q_rope_old}, {k_rope_old}, fail to replace {node} with flashinfer rope" - ) return # Sanity check on head_dim @@ -358,15 +411,8 @@ def _optimize_explicit( q_fake = q_node.meta.get("val", None) if q_fake is not None and len(q_fake.shape) > 2: if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)): - ad_logger.warning( - f"""Sanity check failed: q_fake should have shape [b, s, n, d], - s should be symbolic and n should be int, instead got shape {q_fake.shape}""" - ) return elif q_fake is not None: - ad_logger.warning( - f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}" - ) return head_dim = cos_node.meta["val"].shape[-1] @@ -449,15 +495,8 @@ def _optimize_complex( q_fake = q_node.meta.get("val", None) if q_fake is not None and len(q_fake.shape) > 2: if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)): - ad_logger.warning( - f"""Sanity check failed: q_fake should have shape [b, s, n, d], - s should be symbolic and n should be int, instead got shape {q_fake.shape}""" - ) return elif q_fake is not None: - ad_logger.warning( - f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}" - ) return # Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash @@ -522,35 +561,6 @@ def _match_input_interleave_pattern(node: Node) -> Optional[Dict[str, Node]]: return {"interleaved": raw_node} -def _move_node_before_first_user(node: Node) -> Node: - """ - Remove `node` from the graph and re-insert a clone of it immediately - before its earliest user. Returns the new node. - - If `node` has no users, or is already right before its first user, - this is a no-op and returns the original node. - """ - graph = node.graph - ordering = list(graph.nodes) - - users = list(node.users) - if not users: - return node - - # locate the earliest user in the current ordering - first_user = min(users, key=lambda u: ordering.index(u)) - if ordering.index(node) == ordering.index(first_user) - 1: - return node - - with graph.inserting_before(first_user): - new_node = graph.node_copy(node, lambda n: n) - - node.replace_all_uses_with(new_node) - graph.erase_node(node) - - return new_node - - def _get_last_node(nodes: Sequence[Node]) -> Node: """ Given a list of FX Nodes, @@ -581,36 +591,21 @@ def _validate_rope_inputs(q_node: Node, k_node: Node) -> bool: for name, node in [("q", q_node), ("k", k_node)]: fake_val = node.meta.get("val", None) if fake_val is None: - ad_logger.warning( - f"Meta['val'] for {name} not available; skipping RoPE transformation." - ) return False # Check dtype if fake_val.dtype not in (torch.float16, torch.bfloat16): - ad_logger.warning( - f"""{name} tensor is {fake_val.dtype}, - expected half precision (float16 or bfloat16). Skipping RoPE transformation.""" - ) return False # Check head_dim if len(fake_val.shape) < 1: - ad_logger.warning(f"{name} tensor has invalid shape {fake_val.shape}.") return False head_dim = fake_val.shape[-1] if isinstance(head_dim, int) and head_dim % 64 != 0: - ad_logger.warning( - f"{name} head_dim = {head_dim} is not a multiple of 64. Skipping RoPE transformation." - ) return False # Check shape if not isinstance(fake_val.shape[1], torch.SymInt): - ad_logger.warning( - f"{name} has shape {fake_val.shape} that is not supported. Only support [B, S, N, D] layout.\ - Skipping RoPE transformation." - ) return False return True diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py new file mode 100644 index 0000000000..b4ed58c5d3 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -0,0 +1,452 @@ +"""Transformations to support graph sharding. + +Our sharding algorithm for tensor parallelism (TP) is based on the following steps: + + 1. Initialize/construct unsharded model. Ideally, this should be done on device="meta" to avoid + unnecessary memory allocation. In some cases, this is necessary if the model is too large to + fit on a single device. + 2. Shard the graph IR of the model: + a. Identify linear nodes that correspond to TP tuples. + b. Reduce/Shard shape of weights in the corresponding linear nodes accordingly (either in + row or column dimension). Add all_reduce nodes where necessary (--> only needed for + fusing results in final linear node of the TP tuple). + c. Add a checkpoint loading hook to the sharded linear nodes so that only the correct shard + of the weight from the checkpoint gets loaded. + 3. Load the checkpoint and allocate the tensor. Loading the correct shard from the checkpoint + happens automatically via the checkpoint loading hook added in step 2c. +""" + +import operator +from collections import defaultdict +from typing import DefaultDict, Dict, List, Set, Tuple, Type + +import torch +from pydantic import Field +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.logger import ad_logger +from ...utils.node_utils import identify_regions_between_residuals, is_linear_op, is_op +from ...utils.sharding_utils import ( + BMMShardingInfo, + EPShardingInfo, + ShardingConfig, + ShardingTransformInfo, + SplitDimension, + TPShardingInfo, +) +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +@TransformRegistry.register("sharding_transform_executor") +class ShardingTransformExecutor(BaseTransform): + """Apply transformations to the graph module. + + Args: + gm: Graph module to apply transformations to + sharding_config: Transformation configuration containing list of transformations to apply + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # create a node dict for faster lookup + node_dict = {n.name: n for n in gm.graph.nodes} + + def check_and_apply(transform: ShardingTransformInfo) -> bool: + """Return True if the transformation is applied, False otherwise.""" + if transform.target_node is None or transform.target_node not in node_dict: + ad_logger.warning( + f"Skipping transformation {transform} because target node " + + f"{transform.target_node} not found in graph" + ) + return False + return transform.check_and_apply(gm, node_dict[transform.target_node]) + + num_matches = 0 + for tp_transform in shared_config.sharding_config.tp_transforms: + if check_and_apply(tp_transform): + num_matches += 1 + for bmm_transform in shared_config.sharding_config.bmm_transforms: + if check_and_apply(bmm_transform): + num_matches += 1 + for ep_transform in shared_config.sharding_config.ep_transforms: + if check_and_apply(ep_transform): + num_matches += 1 + + info = TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False + ) + return gm, info + + +def _append_simple_shard( + nodes_linear: Dict[Node, List[Node]], + rank: int, + world_size: int, + sharding_config: ShardingConfig, +) -> None: + # for every linear node: + # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) + tp_shards: List[TPShardingInfo] = [] + for node_group in nodes_linear.values(): + for n in node_group: + tp_shards.append( + TPShardingInfo( + target_node=n.name, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) + ) + sharding_config.tp_transforms.extend(tp_shards) + + +class ColumnRowShardConfig(TransformConfig): + """Configuration for column-row sharding.""" + + simple_shard_only: bool = Field(default=False) + + +@TransformRegistry.register("detect_column_row_shard") +class ColumnRowShard(BaseTransform): + """A transformation to apply sharding to the model following tensor parallelism. + + The transformation is based on the following steps: + + 1. Identify boundary nodes between residual nodes to identify shardable regions. + 2. Identify the GEMM nodes that can be sharded + 3. Trace through the subgraph using DFS/BFS between each pair of boundary nodes + 4. Account for each node in the trace to ensure the op is correct even after sharding. This is + necessary to ensure that the sharding is correct and we need to be able to account for + **all** nodes in the subgraph. The subgraph here is defined as the region between the first + linear node to the last linear node of an identified sharding region. + # 5. Shard the GEMM nodes or skip accordingly. + + min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism + splitting, e.g., the individual heads into smaller shards. + """ + + config: ColumnRowShardConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return ColumnRowShardConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + local_rank, world_size = shared_config.local_rank, shared_config.world_size + + if world_size < 2: + ad_logger.info("Skipping sharding for single device") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + assert isinstance(gm, GraphModule), "Expecting GraphModule" + + # find boundary nodes of regions we want to shard + boundary_nodes = identify_regions_between_residuals(gm) + + # TODO: continue updating these lists + # pointwise ops that don't affect the sharder + pointwise_ops = { + torch.ops.aten.gelu, + torch.ops.aten.leaky_relu, + torch.ops.aten.mul, + torch.ops.aten.relu, + torch.ops.aten.sigmoid, + torch.ops.aten.silu, + torch.ops.aten.tanh, + torch.ops.aten.contiguous, + } + + # acceptable attention nodes between sharded GEMMs + shardable_attention_nodes = { + torch.ops.auto_deploy.torch_attention_sdpa, + torch.ops.auto_deploy.torch_attention_grouped_sdpa, + torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, + } + + # This is a heuristic. Basically, we assume those are okay to shard if we also encounter an + # attention node because we know that those ops must be compatible with the attention op. Now + # since the attention op is shardable, we will assume those are as well if used in conjunction + # with the attention op. + shardable_nodes_with_attention = { + torch.ops.aten.view, + torch.ops.aten.reshape, + torch.ops.auto_deploy.flashinfer_rope, + operator.getitem, + } + + # let's look at linear nodes we can identify between pairs of boundary nodes + # There is three potential cases we can handle: + # 1. No linear nodes: + # --> just continue + # 2. Two groups of linear nodes and we can account for all to the view nodes: + # --> row_split (dim 0) 1st group + check for supported nodes + + # col_split (dim 1) 2nd group + all_reduce output of 2nd group + # 3. Linear nodes that are not in two groups or we cannot account for all nodes: + # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) output + num_shards = 0 + for n_start, n_end in zip(boundary_nodes[:-1], boundary_nodes[1:]): + # we iterate through all nodes between the two boundary nodes and store linear nodes + # sorted by their input activation node. We also store remaining nodes. + nodes_linear: DefaultDict[Node, List[Node]] = defaultdict(list) + attention_nodes: Set[Node] = set() + attention_related_nodes: Set[Node] = set() + unaccounted_nodes: Set[Node] = set() + current_node = n_start + while current_node != n_end: + if is_linear_op(current_node, include_quantization=True): + nodes_linear[current_node.args[0]].append(current_node) + elif is_op(current_node, shardable_attention_nodes): + attention_nodes.add(current_node) + elif is_op(current_node, shardable_nodes_with_attention): + attention_related_nodes.add(current_node) + elif not is_op(current_node, pointwise_ops): + unaccounted_nodes.add(current_node) + current_node = current_node.next + assert current_node, "Could not identify next node" + + # nothing to shard + if len(nodes_linear) == 0: + continue + + num_shards += 1 + + if self.config.simple_shard_only: + ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") + _append_simple_shard( + nodes_linear, local_rank, world_size, shared_config.sharding_config + ) + continue + + # simple shard when we have != 2 groups of linear nodes + if len(nodes_linear) != 2: + ad_logger.debug(f"Linear groups: {nodes_linear}") + _append_simple_shard( + nodes_linear, local_rank, world_size, shared_config.sharding_config + ) + continue + + # let's look at the unnacounted nodes. They are okay as long as they fall before the + # first linear node or after the last linear node, i.e., outside the sharded region + lin_nodes_flat: Set[Node] = {n for group in nodes_linear.values() for n in group} + lin_nodes_passed: Set[Node] = set() + current_node = n_start + while current_node != n_end: + # check if this is another linear node + if current_node in lin_nodes_flat: + lin_nodes_passed.add(current_node) + + # check if we are OUTSIDE sharded region + if len(lin_nodes_passed) == 0 or lin_nodes_passed == lin_nodes_flat: + # remove node from unaccounted nodes since we are outside and it doesn't matter + unaccounted_nodes.discard(current_node) + attention_related_nodes.discard(current_node) + attention_nodes.discard(current_node) + + current_node = current_node.next + + # let's post-process the attention-related nodes + # we can disregard them if we also see attention nodes and we assume they are compatible + if len(attention_nodes) > 0: + attention_related_nodes.clear() + + # check if any unaccounted nodes are left. If so, do a simply shard + if unaccounted_nodes or attention_related_nodes: + ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}") + _append_simple_shard( + nodes_linear, local_rank, world_size, shared_config.sharding_config + ) + continue + + # If we can account for all sharded nodes, we can do a two-way shard + # --> row_split (dim 0) + col_split (dim 1) + all_reduce + + # check if we are sharding the attention block + if attention_nodes: + if len(attention_nodes) > 1: + # Column-row shard boundary region detection is probably wrong - there should be + # only one attention operation. Fall back to simple shard. + ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") + _append_simple_shard( + nodes_linear, local_rank, world_size, shared_config.sharding_config + ) + continue + # Extract head dimension. We cannot shard below the head_dim size. + # Assume that head_dim is the last (innermost) dimension of the tensor + min_local_shape = attention_nodes.pop().meta["val"].shape[-1] + else: + min_local_shape = 1 + for i, group in enumerate(nodes_linear.values()): + for n in group: + if i > 0: + dist_op = "all_reduce" + else: + dist_op = None + shared_config.sharding_config.tp_transforms.append( + TPShardingInfo( + target_node=n.name, + split_dim=i, + rank=local_rank, + world_size=world_size, + dist_op=dist_op, + min_local_shape=min_local_shape, + ) + ) + + info = TransformInfo( + skipped=False, num_matches=num_shards, is_clean=False, has_valid_shapes=False + ) + return gm, info + + +@TransformRegistry.register("detect_dp_bmm_shard") +class DpBmmShard(BaseTransform): + """A transformation to apply sharding to batched matrix multiplications in the graph. + + We'll shard the BMM nodes by slicing the batch dimension of input tensors into world_size number of slices. + After sharding each BMM node, we'll insert an all_gather node to gather the results across the different devices. + This transformation handles any combination of tensor types for both inputs to the BMM operation. + + We'll also assume that the inputs to BMM are broadcasted across the devices already. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + local_rank, world_size = shared_config.local_rank, shared_config.world_size + if world_size < 2: + ad_logger.info("Skipping sharding for single device") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + assert isinstance(gm, GraphModule), "Expecting GraphModule" + + num_bmm_shards = 0 + + for node in gm.graph.nodes: + if not is_op(node, {torch.ops.aten.bmm}): + continue + + # Get the input tensors + lhs_tensor = node.args[0] + rhs_tensor = node.args[1] + + # Check batch sizes from meta information + lhs_batch_size = lhs_tensor.meta["val"].shape[0] + rhs_batch_size = rhs_tensor.meta["val"].shape[0] + + assert lhs_batch_size == rhs_batch_size, "Batch sizes of both tensors must match" + bmm_batch_size = lhs_batch_size + + # Calculate balanced distribution + base_size = bmm_batch_size // world_size + remainder = bmm_batch_size % world_size + + # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. + if remainder: + ad_logger.warning( + f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. " + f"This will result in uneven distribution of work across devices. Skipping." + ) + continue + + # Calculate start and end indices for this rank + if local_rank < remainder: + start_idx = local_rank * (base_size + 1) + end_idx = start_idx + base_size + 1 + else: + start_idx = remainder + local_rank * base_size + end_idx = start_idx + base_size + + shared_config.sharding_config.bmm_transforms.append( + BMMShardingInfo( + target_node=node.name, + rank=local_rank, + world_size=world_size, + start_idx=start_idx, + end_idx=end_idx, + ) + ) + ad_logger.debug( + f"Sharding BMM for rank {local_rank}: " + f"batch_size={bmm_batch_size}, " + f"start_idx={start_idx}, end_idx={end_idx}" + ) + + num_bmm_shards += 1 + + info = TransformInfo( + skipped=False, num_matches=num_bmm_shards, is_clean=False, has_valid_shapes=False + ) + return gm, info + + +@TransformRegistry.register("detect_ep_shard") +class DetectEpShard(BaseTransform): + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + local_rank, world_size = shared_config.local_rank, shared_config.world_size + + if world_size < 2: + ad_logger.info("Skipping sharding for single device") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + assert isinstance(gm, GraphModule), "Expecting GraphModule" + num_moe_patterns = 0 + for node in list(gm.graph.nodes): + if not is_op( + node, + ( + torch.ops.auto_deploy.torch_moe, + torch.ops.auto_deploy.torch_quant_fp8_moe, + torch.ops.auto_deploy.torch_quant_fp4_moe, + ), + ): + continue + shared_config.sharding_config.ep_transforms.append( + EPShardingInfo( + target_node=node.name, + rank=local_rank, + world_size=world_size, + ) + ) + num_moe_patterns += 1 + + info = TransformInfo( + skipped=False, num_matches=num_moe_patterns, is_clean=False, has_valid_shapes=False + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py index 2aac699327..53659bf814 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py @@ -2,13 +2,16 @@ from typing import Optional +import torch.distributed as dist import torch.nn as nn from torch.fx import Graph, GraphModule +from ..distributed import common as dist_ad from ..models.factory import ModelFactory from ..shim.interface import CachedSequenceInterface from .interface import ( InferenceOptimizerConfig, + SharedConfig, Stages, StrictInferenceOptimizerConfig, TransformConfig, @@ -20,6 +23,11 @@ class InferenceOptimizer: def __init__(self, factory: ModelFactory, config: InferenceOptimizerConfig): self.factory = factory self.config = self._clean_config(config) + if not dist.is_initialized(): + local_rank, world_size = 0, 1 + else: + local_rank, world_size = dist_ad.get_rank_world_size() + self.shared_config = SharedConfig(local_rank=local_rank, world_size=world_size) def _clean_config(self, config: InferenceOptimizerConfig) -> StrictInferenceOptimizerConfig: """Get a typed checked ("strict") config with sorted keys according to stages.""" @@ -68,7 +76,7 @@ class InferenceOptimizer: # instantiate transform transform = TransformRegistry.get(t_name)(t_config) # run transform - gm = transform(gm, cm, self.factory) + gm = transform(gm, cm, self.factory, self.shared_config) ############################################################################################ # RETURN OPTIMIZED GRAPH diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py index 4a39c7f662..0d4c388ebc 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py @@ -1,13 +1,10 @@ """A library of transformation passes.""" from .collectives import * -from .eliminate_redundant_transposes import * from .fused_moe import * from .fusion import * from .kvcache import * from .rms_norm import * -from .rope import * -from .sharding import * try: from .visualization import visualize_namespace diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py deleted file mode 100644 index a8c6668dde..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Graph transformation to eliminate redundant transpose operations in the model graph. - -This transformation identifies and removes patterns where transpose operations with the same -dimensions are applied consecutively, which cancel each other out: -x = x.transpose(1, 2) -x = x.transpose(1, 2) -""" - -from typing import Set, Tuple - -import torch -from torch.fx import GraphModule, Node - -from ...utils.logger import ad_logger -from ...utils.node_utils import is_op -from .._graph import canonicalize_graph - - -def _is_transpose_op(node: Node) -> bool: - """Check if the node is a transpose operation.""" - return is_op(node, torch.ops.aten.transpose) - - -def _is_contiguous_op(node: Node) -> bool: - """Check if the node is a contiguous operation.""" - return is_op(node, torch.ops.aten.contiguous) - - -def _are_transpose_args_same(node1: Node, node2: Node) -> bool: - """Check if two transpose nodes have the same dimension arguments.""" - # Get the dimension arguments for both nodes - # Args structure: (input_tensor, dim1, dim2) - if len(node1.args) < 3 or len(node2.args) < 3: - return False - - dim1_node1, dim2_node1 = node1.args[1], node1.args[2] - dim1_node2, dim2_node2 = node2.args[1], node2.args[2] - - # Check if the dimensions are the same - return dim1_node1 == dim1_node2 and dim2_node1 == dim2_node2 - - -def eliminate_redundant_transposes(gm: GraphModule) -> None: - """Eliminate redundant transpose operations in the graph. - - This transformation identifies pairs of consecutive transpose operations with - the same dimension arguments and removes both operations, as they cancel out. - """ - ad_logger.debug("Before eliminating redundant transposes: " + str(gm)) - - graph = gm.graph - - # Find pairs of redundant transpose operations - nodes_to_eliminate: Set[Tuple[Node, Node]] = set() - - for t_node in gm.graph.nodes: - # check if there is a transpose operation - if not _is_transpose_op(t_node): - continue - - # check if it's already part of a pair - if any(t_node in pair for pair in nodes_to_eliminate): - continue - - # check if there is only one user - if len(t_node.users) > 1: - continue - - # check if the user is a contiguous operation - t_comp_node = list(t_node.users)[0] - - # check if the user is a contiguous operation - has_contiguous = False - while _is_contiguous_op(t_comp_node) and len(t_comp_node.users) == 1: - has_contiguous = True - t_comp_node = list(t_comp_node.users)[0] - - # check if the user is a transpose operation - if not _is_transpose_op(t_comp_node): - continue - - # check if the transpose operation has the same dimension arguments - if not _are_transpose_args_same(t_node, t_comp_node): - continue - - # add the pair to the set - nodes_to_eliminate.add((t_node, t_comp_node, has_contiguous)) - - # Eliminate redundant transpose pairs - for t_node, t_comp_node, has_contiguous in nodes_to_eliminate: - # Replace all uses of the second transpose with the input to the first transpose - original_input = t_node.args[0] - t_comp_node.replace_all_uses_with(original_input) - - # if there is a contiguous operation that we skipped, let add it after t_comp_node as new - # graph node that call contiguous on t_comp_node - if has_contiguous: - with graph.inserting_after(original_input): - new_contiguous_node = graph.call_function( - torch.ops.aten.contiguous.default, args=(original_input,) - ) - original_input.replace_all_uses_with(new_contiguous_node) - new_contiguous_node.replace_input_with(new_contiguous_node, original_input) - - ad_logger.debug(f"Eliminated redundant transpose pair: {t_node} -> {t_comp_node}") - - # Clean up the graph - if nodes_to_eliminate: - gm.graph.eliminate_dead_code() - canonicalize_graph(gm) - ad_logger.info(f"Found and eliminated {len(nodes_to_eliminate)} redundant transpose pairs") - ad_logger.debug("After eliminating redundant transposes: " + str(gm)) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index 3844ce4d31..c841b4601f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -7,29 +7,17 @@ import torch.nn as nn from ..compile import compile_and_capture from ..custom_ops.attention_interface import AttentionRegistry -from ..distributed import common as dist_ad from ..llm_args import AutoDeployConfig from ..models.factory import ModelFactory from ..shim.interface import CachedSequenceInterface from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer from ..utils.logger import ad_logger -from ._graph import canonicalize_graph, lift_to_meta, move_to_device from .library import ( - ShardingConfig, - detect_column_row_shard, - detect_dp_bmm_shard, - detect_ep_shard, - eliminate_redundant_transposes, fuse_allreduce_residual_rmsnorm, fuse_collectives, fuse_rmsnorm, insert_cached_attention, - match_moe_pattern, - match_rope_layout, - match_rope_pattern, - optimize_rope, resize_kv_cache, - sharding_transform_executor, update_in_out_nodes, ) @@ -57,76 +45,28 @@ class InferenceOptimizer: # RUN MODULAR INFERENCE OPTIMIZER FOR ALREADY-MIGRATED TRANSFORMS ############################################################################################ # TODO (hg): default values that are not representable in YAML. + # move to the optimizer if "match_attention_layout" in self.ad_config.transforms: self.ad_config.transforms[ "match_attention_layout" ].attention_op = AttentionRegistry.get(self.ad_config.attn_backend) + if "match_rope_layout" in self.ad_config.transforms: + self.ad_config.transforms["match_rope_layout"].expected_layout = AttentionRegistry.get( + self.ad_config.attn_backend + ).get_attention_layout() new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms) + + # TODO (hg): similar to above. + if "load_weights" in new_optimizer.config: + new_optimizer.config[ + "load_weights" + ].checkpoint_device = self.ad_config.checkpoint_device + new_optimizer.config["load_weights"].device = cm.device + egm = new_optimizer(cm) # TODO (lucaslie): continue moving legacy transforms to the new optimizer - - ############################################################################################ - # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION - ############################################################################################ - - # Match MoE pattern - match_moe_pattern(egm) - - # Match rope - match_rope_pattern(egm) - - # Match RoPE layout expected by our backend - match_rope_layout( - egm, AttentionRegistry.get(self.ad_config.attn_backend).get_attention_layout() - ) - - ############################################################################################ - # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION - ############################################################################################ - - local_rank, world_size = dist_ad.get_rank_world_size() - - # eliminate redundant transpose operations - eliminate_redundant_transposes(egm) - - # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved - # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528 - optimize_rope(egm) - - # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config. - sharding_config = ShardingConfig() - - # run TP sharding across ranks - detect_column_row_shard( - egm, local_rank, world_size, sharding_config, self.ad_config.simple_shard_only - ) - - # run EP sharding across ranks - detect_ep_shard(egm, local_rank, world_size, sharding_config) - - # run BMM sharding across ranks - detect_dp_bmm_shard(egm, local_rank, world_size, sharding_config) - - sharding_transform_executor(egm, sharding_config) - - # let's run a shape propagation pass to update the graph with correct meta values for - # subsequent optimization passes. Lift state_dict to meta as shape propagation involves device check - with lift_to_meta(egm): - canonicalize_graph(egm, shape_prop=True) - - ############################################################################################ - # MOVE MODEL AND LOAD WEIGHTS - ############################################################################################ - - # load weights - self.factory.load_or_random_init(egm, device=self.ad_config.checkpoint_device or cm.device) - - # move remaining parts to device - move_to_device(egm, cm.device) - cm.to(cm.device) - ############################################################################################ # RUN POST-LOAD FUSION AND OPTIMIZATIONS ############################################################################################ diff --git a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py index f207584518..0587874598 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py @@ -304,7 +304,7 @@ class FP4QuantizationImpl(QuantizationImpl): weight_scale = state_dict[weight_name + "_scale"].view(float4_sf_dtype) ori_shape = weight_scale.shape state_dict[weight_name + "_scale"] = ( - torch.ops.trtllm.nvfp4_block_scale_interleave( + torch.ops.trtllm.block_scale_interleave( weight_scale.view(torch.uint8).cpu().contiguous() ) .reshape(ori_shape) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py similarity index 52% rename from tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py rename to tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index d7ed5918a4..e0c8cd65ca 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -1,287 +1,20 @@ -"""Transformations to support graph sharding. - -Our sharding algorithm for tensor parallelism (TP) is based on the following steps: - - 1. Initialize/construct unsharded model. Ideally, this should be done on device="meta" to avoid - unnecessary memory allocation. In some cases, this is necessary if the model is too large to - fit on a single device. - 2. Shard the graph IR of the model: - a. Identify linear nodes that correspond to TP tuples. - b. Reduce/Shard shape of weights in the corresponding linear nodes accordingly (either in - row or column dimension). Add all_reduce nodes where necessary (--> only needed for - fusing results in final linear node of the TP tuple). - c. Add a checkpoint loading hook to the sharded linear nodes so that only the correct shard - of the weight from the checkpoint gets loaded. - 3. Load the checkpoint and allocate the tensor. Loading the correct shard from the checkpoint - happens automatically via the checkpoint loading hook added in step 2c. -""" +"""Sharding config definitions for the inference optimizer.""" import math import operator from abc import ABC, abstractmethod -from collections import defaultdict from enum import IntEnum from functools import partial -from typing import Callable, DefaultDict, Dict, List, Literal, Optional, Set +from typing import Callable, Dict, List, Literal, Optional import torch import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field from torch.fx import GraphModule, Node -from ...utils.logger import ad_logger -from ...utils.node_utils import ( - extract_param_names_from_lin_node, - identify_regions_between_residuals, - is_linear_op, - is_op, - num_users_of_weight_node, -) -from ...utils.quantization_utils import QuantizationImpl -from .._graph import canonicalize_graph - - -class SplitDimension(IntEnum): - """Enum for tensor split dimensions in sharding.""" - - ROW = 0 # Split along rows (first dimension) - COLUMN = 1 # Split along columns (second dimension) - - -class ShardingTransformInfo(BaseModel, ABC): - """Abstract base class for transformation configurations.""" - - model_config = ConfigDict(frozen=True) # Makes the model immutable and hashable - - target_node: str - rank: int - world_size: int - - def validate(self, gm: GraphModule = None, node: Node = None) -> bool: - """ - Validate whether the transformation is valid. - Execute right before applying the transformation. - """ - return True - - @abstractmethod - def apply(self, gm: GraphModule, node: Node) -> None: - """Apply the transformation to the graph module. - - This method must be implemented by each transformation class. - """ - pass - - def check_and_apply(self, gm: GraphModule, node: Node) -> None: - """Check if the transformation is valid and apply it if it is.""" - if not self.validate(gm, node): - ad_logger.warning(f"Skipping invalid transformation {self}.") - return - self.apply(gm, node) - - -class TPShardingInfo(ShardingTransformInfo): - """Configuration for TP sharding transformations.""" - - split_dim: SplitDimension - dist_op: Optional[Literal["all_reduce", "all_gather"]] = None - min_local_shape: int = 1 - - def validate(self, gm: GraphModule = None, node: Node = None) -> bool: - """Validate the transformation configuration.""" - if self.dist_op is not None: - if self.split_dim == SplitDimension.ROW: - if self.dist_op == "all_reduce": - ad_logger.warning( - f"Row split is only supported for all_gather. Skipping {self}." - ) - return False - if self.split_dim == SplitDimension.COLUMN: - if self.dist_op == "all_gather": - ad_logger.warning( - f"Column split is only supported for all_reduce. Skipping {self}." - ) - return False - return True - - def apply(self, gm: GraphModule, node: Node) -> None: - """Apply TP sharding transformation to the graph module.""" - - _insert_sharded_matmul( - gm=gm, - node=node, - dim=self.split_dim.value, - rank=self.rank, - world_size=self.world_size, - add_dist=self.dist_op is not None, - min_local_shape=self.min_local_shape, - ) - - -class BMMShardingInfo(ShardingTransformInfo): - """Configuration for BMM sharding transformations.""" - - rank: int - world_size: int - start_idx: int - end_idx: int - - def validate(self, gm: GraphModule = None, node: Node = None) -> bool: - """Validate the transformation configuration.""" - if not is_op(node, torch.ops.aten.bmm): - ad_logger.warning(f"BMM sharding is only supported for BMM nodes. Skipping {self}.") - return False - - # Get the input tensors - lhs_tensor = node.args[0] - rhs_tensor = node.args[1] - - # Check batch sizes from meta information - lhs_batch_size = lhs_tensor.meta["val"].shape[0] - rhs_batch_size = rhs_tensor.meta["val"].shape[0] - - assert lhs_batch_size == rhs_batch_size, "Batch sizes of both tensors must match" - bmm_batch_size = lhs_batch_size - - # Check if the distribution is balanced - remainder = bmm_batch_size % self.world_size - - # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. - if remainder: - ad_logger.warning( - f"BMM batch size {bmm_batch_size} is not divisible by world size {self.world_size}. " - f"This will result in uneven distribution of work across devices. Skipping." - ) - return False - return True - - def apply(self, gm: GraphModule, node: Node) -> None: - """Apply BMM sharding transformation to the graph module.""" - - def handle_tensor( - bmm_node: Node, tensor_node: Node, arg_idx: int, start_idx: int, end_idx: int - ): - """Unified helper function to shard either a parameter tensor or a dynamic tensor. - - Args: - bmm_node: The BMM node that is being processed - tensor_node: The input tensor node to shard - arg_idx: The argument index of the tensor in the BMM node - start_idx: Start index for sharding - end_idx: End index for sharding - """ - - # Define slice function for the sharding - def slice_tensor(t: torch.Tensor) -> torch.Tensor: - return t[start_idx:end_idx] - - if tensor_node.op == "get_attr": - # Handle parameter tensor - weight_key = tensor_node.target - modname, _, param_name = weight_key.rpartition(".") - param = gm.get_parameter(weight_key) - - # Update the parameter with its shard - param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True) - gm.get_submodule(modname).register_parameter(param_name, param_new) - - # Register load state dict hook - gm._register_load_state_dict_pre_hook( - partial( - _load_hook, - f_split=slice_tensor, - param_key=weight_key, - param_shape=param_new.shape, - ) - ) - else: - # Handle dynamic tensor - with gm.graph.inserting_before(bmm_node): - tensor_slice = gm.graph.call_function( - torch.ops.aten.slice.Tensor, args=(tensor_node, 0, start_idx, end_idx, 1) - ) - # Update BMM node to use the sliced tensor - bmm_node.update_arg(arg_idx, tensor_slice) - - # Get the input tensors - lhs_tensor = node.args[0] - rhs_tensor = node.args[1] - # Handle both tensors - handle_tensor(node, lhs_tensor, 0, self.start_idx, self.end_idx) - handle_tensor(node, rhs_tensor, 1, self.start_idx, self.end_idx) - - # Add all_gather node after BMM to collect results - with gm.graph.inserting_after(node): - gather_node = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_gather, - args=(node, 0), # Gather along batch dimension (0) - ) - node.replace_all_uses_with(gather_node) - gather_node.replace_input_with(gather_node, node) - - -class EPShardingInfo(ShardingTransformInfo): - """Configuration for EP sharding transformations.""" - - rank: int - world_size: int - - def validate(self, gm: GraphModule = None, node: Node = None) -> bool: - """Validate the transformation configuration.""" - if not is_op( - node, - ( - torch.ops.auto_deploy.torch_moe, - torch.ops.auto_deploy.torch_quant_fp8_moe, - torch.ops.auto_deploy.torch_quant_fp4_moe, - ), - ): - ad_logger.warning(f"EP sharding is only supported for MOE nodes. Skipping {self}.") - return False - return True - - def apply(self, gm: GraphModule, node: Node) -> None: - """Apply EP sharding transformation to the graph module.""" - _insert_sharded_moe(gm, node, self.rank, self.world_size) - - -class ShardingConfig(BaseModel): - """Configuration for sharding the model.""" - - tp_transforms: List[TPShardingInfo] = Field(default_factory=list) - bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) - ep_transforms: List[EPShardingInfo] = Field(default_factory=list) - - -def sharding_transform_executor(gm: GraphModule, sharding_config: ShardingConfig) -> None: - """Apply transformations to the graph module. - - Args: - gm: Graph module to apply transformations to - sharding_config: Transformation configuration containing list of transformations to apply - """ - # create a node dict for faster lookup - node_dict = {n.name: n for n in gm.graph.nodes} - - def check_and_apply(transform: ShardingTransformInfo) -> None: - if transform.target_node is None or transform.target_node not in node_dict: - ad_logger.warning( - f"Skipping transformation {transform} because target node " - + f"{transform.target_node} not found in graph" - ) - return - transform.check_and_apply(gm, node_dict[transform.target_node]) - - for tp_transform in sharding_config.tp_transforms: - check_and_apply(tp_transform) - for bmm_transform in sharding_config.bmm_transforms: - check_and_apply(bmm_transform) - for ep_transform in sharding_config.ep_transforms: - check_and_apply(ep_transform) - - # canonicalize and return - gm = canonicalize_graph(gm) - ad_logger.debug("After applying sharding transformations: " + str(gm)) +from ..utils.logger import ad_logger +from .node_utils import extract_param_names_from_lin_node, is_op, num_users_of_weight_node +from .quantization_utils import QuantizationImpl def _load_hook( @@ -446,234 +179,100 @@ def _insert_sharded_matmul( dist_node.replace_input_with(dist_node, node) -def _append_simple_shard( - nodes_linear: Dict[Node, List[Node]], - rank: int, - world_size: int, - sharding_config: ShardingConfig, -) -> None: - # for every linear node: - # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[TPShardingInfo] = [] - for node_group in nodes_linear.values(): - for n in node_group: - tp_shards.append( - TPShardingInfo( - target_node=n.name, - split_dim=SplitDimension.ROW, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, - ) - ) - sharding_config.tp_transforms.extend(tp_shards) +class SplitDimension(IntEnum): + """Enum for tensor split dimensions in sharding.""" + + ROW = 0 # Split along rows (first dimension) + COLUMN = 1 # Split along columns (second dimension) -def detect_column_row_shard( - gm: GraphModule, - rank: int, - world_size: int, - sharding_config: ShardingConfig, - simple_shard_only: bool = False, -) -> None: - """A transformation to apply sharding to the model following tensor parallelism. +class ShardingTransformInfo(BaseModel, ABC): + """Abstract base class for transformation configurations.""" - The transformation is based on the following steps: + model_config = ConfigDict(frozen=True) # Makes the model immutable and hashable - 1. Identify boundary nodes between residual nodes to identify shardable regions. - 2. Identify the GEMM nodes that can be sharded - 3. Trace through the subgraph using DFS/BFS between each pair of boundary nodes - 4. Account for each node in the trace to ensure the op is correct even after sharding. This is - necessary to ensure that the sharding is correct and we need to be able to account for - **all** nodes in the subgraph. The subgraph here is defined as the region between the first - linear node to the last linear node of an identified sharding region. - # 5. Shard the GEMM nodes or skip accordingly. + target_node: str + rank: int + world_size: int - min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism - splitting, e.g., the individual heads into smaller shards. - """ - ad_logger.debug("Before sharding graph: " + str(gm)) + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + """ + Validate whether the transformation is valid. + Execute right before applying the transformation. + """ + return True - if world_size < 2: - ad_logger.info("Skipping sharding for single device") - return + @abstractmethod + def apply(self, gm: GraphModule, node: Node) -> None: + """Apply the transformation to the graph module. - assert isinstance(gm, GraphModule), "Expecting GraphModule" + This method must be implemented by each transformation class. + """ + pass - # find boundary nodes of regions we want to shard - boundary_nodes = identify_regions_between_residuals(gm) + def check_and_apply(self, gm: GraphModule, node: Node) -> bool: + """ + Check if the transformation is valid and apply it if it is. + Return True if the transformation is applied, False otherwise. + """ + if not self.validate(gm, node): + ad_logger.warning(f"Skipping invalid transformation {self}.") + return False + self.apply(gm, node) + return True - # TODO: continue updating these lists - # pointwise ops that don't affect the sharder - pointwise_ops = { - torch.ops.aten.gelu, - torch.ops.aten.leaky_relu, - torch.ops.aten.mul, - torch.ops.aten.relu, - torch.ops.aten.sigmoid, - torch.ops.aten.silu, - torch.ops.aten.tanh, - torch.ops.aten.contiguous, - } - # acceptable attention nodes between sharded GEMMs - shardable_attention_nodes = { - torch.ops.auto_deploy.torch_attention_sdpa, - torch.ops.auto_deploy.torch_attention_grouped_sdpa, - torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, - } +class TPShardingInfo(ShardingTransformInfo): + """Configuration for TP sharding transformations.""" - # This is a heuristic. Basically, we assume those are okay to shard if we also encounter an - # attention node because we know that those ops must be compatible with the attention op. Now - # since the attention op is shardable, we will assume those are as well if used in conjunction - # with the attention op. - shardable_nodes_with_attention = { - torch.ops.aten.view, - torch.ops.aten.reshape, - torch.ops.auto_deploy.flashinfer_rope, - operator.getitem, - } + split_dim: SplitDimension + dist_op: Optional[Literal["all_reduce", "all_gather"]] = None + min_local_shape: int = 1 - # let's look at linear nodes we can identify between pairs of boundary nodes - # There is three potential cases we can handle: - # 1. No linear nodes: - # --> just continue - # 2. Two groups of linear nodes and we can account for all to the view nodes: - # --> row_split (dim 0) 1st group + check for supported nodes + - # col_split (dim 1) 2nd group + all_reduce output of 2nd group - # 3. Linear nodes that are not in two groups or we cannot account for all nodes: - # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) output - num_shards = 0 - for n_start, n_end in zip(boundary_nodes[:-1], boundary_nodes[1:]): - # we iterate through all nodes between the two boundary nodes and store linear nodes - # sorted by their input activation node. We also store remaining nodes. - nodes_linear: DefaultDict[Node, List[Node]] = defaultdict(list) - attention_nodes: Set[Node] = set() - attention_related_nodes: Set[Node] = set() - unaccounted_nodes: Set[Node] = set() - current_node = n_start - while current_node != n_end: - if is_linear_op(current_node, include_quantization=True): - nodes_linear[current_node.args[0]].append(current_node) - elif is_op(current_node, shardable_attention_nodes): - attention_nodes.add(current_node) - elif is_op(current_node, shardable_nodes_with_attention): - attention_related_nodes.add(current_node) - elif not is_op(current_node, pointwise_ops): - unaccounted_nodes.add(current_node) - current_node = current_node.next - assert current_node, "Could not identify next node" - - # nothing to shard - if len(nodes_linear) == 0: - continue - - num_shards += 1 - - if simple_shard_only: - ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - continue - - # simple shard when we have != 2 groups of linear nodes - if len(nodes_linear) != 2: - ad_logger.debug(f"Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - continue - - # let's look at the unnacounted nodes. They are okay as long as they fall before the - # first linear node or after the last linear node, i.e., outside the sharded region - lin_nodes_flat: Set[Node] = {n for group in nodes_linear.values() for n in group} - lin_nodes_passed: Set[Node] = set() - current_node = n_start - while current_node != n_end: - # check if this is another linear node - if current_node in lin_nodes_flat: - lin_nodes_passed.add(current_node) - - # check if we are OUTSIDE sharded region - if len(lin_nodes_passed) == 0 or lin_nodes_passed == lin_nodes_flat: - # remove node from unaccounted nodes since we are outside and it doesn't matter - unaccounted_nodes.discard(current_node) - attention_related_nodes.discard(current_node) - attention_nodes.discard(current_node) - - current_node = current_node.next - - # let's post-process the attention-related nodes - # we can disregard them if we also see attention nodes and we assume they are compatible - if len(attention_nodes) > 0: - attention_related_nodes.clear() - - # check if any unaccounted nodes are left. If so, do a simply shard - if unaccounted_nodes or attention_related_nodes: - ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - continue - - # If we can account for all sharded nodes, we can do a two-way shard - # --> row_split (dim 0) + col_split (dim 1) + all_reduce - - # check if we are sharding the attention block - if attention_nodes: - if len(attention_nodes) > 1: - # Column-row shard boundary region detection is probably wrong - there should be - # only one attention operation. Fall back to simple shard. - ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - continue - # Extract head dimension. We cannot shard below the head_dim size. - # Assume that head_dim is the last (innermost) dimension of the tensor - min_local_shape = attention_nodes.pop().meta["val"].shape[-1] - else: - min_local_shape = 1 - for i, group in enumerate(nodes_linear.values()): - for n in group: - if i > 0: - dist_op = "all_reduce" - else: - dist_op = None - sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=n.name, - split_dim=i, - rank=rank, - world_size=world_size, - dist_op=dist_op, - min_local_shape=min_local_shape, + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + """Validate the transformation configuration.""" + if self.dist_op is not None: + if self.split_dim == SplitDimension.ROW: + if self.dist_op == "all_reduce": + ad_logger.warning( + f"Row split is only supported for all_gather. Skipping {self}." ) - ) + return False + if self.split_dim == SplitDimension.COLUMN: + if self.dist_op == "all_gather": + ad_logger.warning( + f"Column split is only supported for all_reduce. Skipping {self}." + ) + return False + return True - ad_logger.info(f"Found {num_shards} TP shards") + def apply(self, gm: GraphModule, node: Node) -> None: + """Apply TP sharding transformation to the graph module.""" + + _insert_sharded_matmul( + gm=gm, + node=node, + dim=self.split_dim.value, + rank=self.rank, + world_size=self.world_size, + add_dist=self.dist_op is not None, + min_local_shape=self.min_local_shape, + ) -def detect_dp_bmm_shard( - gm: GraphModule, rank: int, world_size: int, sharding_config: ShardingConfig -) -> None: - """A transformation to apply sharding to batched matrix multiplications in the graph. +class BMMShardingInfo(ShardingTransformInfo): + """Configuration for BMM sharding transformations.""" - We'll shard the BMM nodes by slicing the batch dimension of input tensors into world_size number of slices. - After sharding each BMM node, we'll insert an all_gather node to gather the results across the different devices. - This transformation handles any combination of tensor types for both inputs to the BMM operation. + rank: int + world_size: int + start_idx: int + end_idx: int - We'll also assume that the inputs to BMM are broadcasted across the devices already. - """ - ad_logger.debug("Before sharding graph: " + str(gm)) - - if world_size < 2: - ad_logger.info("Skipping sharding for single device") - return - - assert isinstance(gm, GraphModule), "Expecting GraphModule" - - num_bmm_shards = 0 - - for node in gm.graph.nodes: - if not is_op(node, {torch.ops.aten.bmm}): - continue - - ad_logger.debug(f"Found BMM node: {node}") + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + """Validate the transformation configuration.""" + if not is_op(node, torch.ops.aten.bmm): + ad_logger.warning(f"BMM sharding is only supported for BMM nodes. Skipping {self}.") + return False # Get the input tensors lhs_tensor = node.args[0] @@ -686,79 +285,81 @@ def detect_dp_bmm_shard( assert lhs_batch_size == rhs_batch_size, "Batch sizes of both tensors must match" bmm_batch_size = lhs_batch_size - # Calculate balanced distribution - base_size = bmm_batch_size // world_size - remainder = bmm_batch_size % world_size + # Check if the distribution is balanced + remainder = bmm_batch_size % self.world_size # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. if remainder: ad_logger.warning( - f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. " + f"BMM batch size {bmm_batch_size} is not divisible by world size {self.world_size}. " f"This will result in uneven distribution of work across devices. Skipping." ) - continue + return False + return True - # Calculate start and end indices for this rank - if rank < remainder: - start_idx = rank * (base_size + 1) - end_idx = start_idx + base_size + 1 - else: - start_idx = remainder + rank * base_size - end_idx = start_idx + base_size + def apply(self, gm: GraphModule, node: Node) -> None: + """Apply BMM sharding transformation to the graph module.""" - sharding_config.bmm_transforms.append( - BMMShardingInfo( - target_node=node.name, - rank=rank, - world_size=world_size, - start_idx=start_idx, - end_idx=end_idx, - ) - ) - ad_logger.debug( - f"Sharding BMM for rank {rank}: batch_size={bmm_batch_size}, start_idx={start_idx}, end_idx={end_idx}" - ) - - num_bmm_shards += 1 - - # Canonicalize and return - if num_bmm_shards: - gm = canonicalize_graph(gm) - ad_logger.debug("After sharding BMM: " + str(gm)) - ad_logger.info(f"Found {num_bmm_shards} BMM shards") - - -def detect_ep_shard( - gm: GraphModule, rank: int, world_size: int, sharding_config: ShardingConfig -) -> None: - ad_logger.debug("Before sharding graph: " + str(gm)) - - if world_size < 2: - ad_logger.info("Skipping sharding for single device") - return - - assert isinstance(gm, GraphModule), "Expecting GraphModule" - num_moe_patterns = 0 - for node in list(gm.graph.nodes): - if not is_op( - node, - ( - torch.ops.auto_deploy.torch_moe, - torch.ops.auto_deploy.torch_quant_fp8_moe, - torch.ops.auto_deploy.torch_quant_fp4_moe, - ), + def handle_tensor( + bmm_node: Node, tensor_node: Node, arg_idx: int, start_idx: int, end_idx: int ): - continue - sharding_config.ep_transforms.append( - EPShardingInfo( - target_node=node.name, - rank=rank, - world_size=world_size, - ) - ) - num_moe_patterns += 1 + """Unified helper function to shard either a parameter tensor or a dynamic tensor. - ad_logger.info(f"Found {num_moe_patterns} MoE patterns") + Args: + bmm_node: The BMM node that is being processed + tensor_node: The input tensor node to shard + arg_idx: The argument index of the tensor in the BMM node + start_idx: Start index for sharding + end_idx: End index for sharding + """ + + # Define slice function for the sharding + def slice_tensor(t: torch.Tensor) -> torch.Tensor: + return t[start_idx:end_idx] + + if tensor_node.op == "get_attr": + # Handle parameter tensor + weight_key = tensor_node.target + modname, _, param_name = weight_key.rpartition(".") + param = gm.get_parameter(weight_key) + + # Update the parameter with its shard + param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True) + gm.get_submodule(modname).register_parameter(param_name, param_new) + + # Register load state dict hook + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=slice_tensor, + param_key=weight_key, + param_shape=param_new.shape, + ) + ) + else: + # Handle dynamic tensor + with gm.graph.inserting_before(bmm_node): + tensor_slice = gm.graph.call_function( + torch.ops.aten.slice.Tensor, args=(tensor_node, 0, start_idx, end_idx, 1) + ) + # Update BMM node to use the sliced tensor + bmm_node.update_arg(arg_idx, tensor_slice) + + # Get the input tensors + lhs_tensor = node.args[0] + rhs_tensor = node.args[1] + # Handle both tensors + handle_tensor(node, lhs_tensor, 0, self.start_idx, self.end_idx) + handle_tensor(node, rhs_tensor, 1, self.start_idx, self.end_idx) + + # Add all_gather node after BMM to collect results + with gm.graph.inserting_after(node): + gather_node = gm.graph.call_function( + torch.ops.auto_deploy.torch_dist_all_gather, + args=(node, 0), # Gather along batch dimension (0) + ) + node.replace_all_uses_with(gather_node) + gather_node.replace_input_with(gather_node, node) def _insert_sharded_moe( @@ -846,3 +447,36 @@ def _insert_sharded_moe( ) node.replace_all_uses_with(dist_node) dist_node.replace_input_with(dist_node, node) + + +class EPShardingInfo(ShardingTransformInfo): + """Configuration for EP sharding transformations.""" + + rank: int + world_size: int + + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + """Validate the transformation configuration.""" + if not is_op( + node, + ( + torch.ops.auto_deploy.torch_moe, + torch.ops.auto_deploy.torch_quant_fp8_moe, + torch.ops.auto_deploy.torch_quant_fp4_moe, + ), + ): + ad_logger.warning(f"EP sharding is only supported for MOE nodes. Skipping {self}.") + return False + return True + + def apply(self, gm: GraphModule, node: Node) -> None: + """Apply EP sharding transformation to the graph module.""" + _insert_sharded_moe(gm, node, self.rank, self.world_size) + + +class ShardingConfig(BaseModel): + """Configuration for sharding the model.""" + + tp_transforms: List[TPShardingInfo] = Field(default_factory=list) + bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) + ep_transforms: List[EPShardingInfo] = Field(default_factory=list) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 89d866ee9a..da4df91f69 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -25,8 +25,8 @@ class DynamicTensorSpec: """ input_idx: int dim_idx: int - gen_tuning_buckets: Union[Tuple[int], Callable] - map_to_tuning_buckets: Callable + gen_tuning_buckets: Union[Tuple[int], Callable] = () + map_to_tuning_buckets: Callable = lambda x: x @dataclass(slots=True, unsafe_hash=True) @@ -43,7 +43,7 @@ class ConstraintSpec: infer_shape: Callable -@dataclass(kw_only=True, unsafe_hash=True) +@dataclass(kw_only=True) class TuningConfig: """Configuration for autotuning. @@ -81,9 +81,15 @@ class TuningConfig: ... ), ... ) ... ) + tune_max_num_tokens (int): The maximum saturation number of tokens to be tuned. + During the inference, the input tensor will be saturated with the same value. Or if + any value is provided to the choose_one function, the input tensor will be saturated + with the provided value. + If not provided, the autotuner will not consider the max num tokens. """ dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = () constraint_specs: Tuple[ConstraintSpec, ...] = () + tune_max_num_tokens: int = None @dataclass(unsafe_hash=True) @@ -139,12 +145,13 @@ class TunableRunner(ABC): @abstractmethod def get_valid_tactics(self, inputs: List[torch.Tensor], - profile: OptimizationProfile) -> List[int]: + profile: OptimizationProfile, **kwargs) -> List[Any]: """One tactic corresponding to one cuda kernel normally, but how to interpret the meaning of tactic is pure internal details of the runner. - The autotuner will just pass the tactic value to the forward w/o any knowledge on what the tactic - means. + The autotuner will just pass the tactic value to the forward w/o. any knowledge on what the tactic + means. User can choose to implement their own types of tactic for flexibility, such as using a dict-typed + to represent a collection of named configs. tactic==-1 has special meaning, means the fallback kernel which should be able to implement any shapes This fallback tactic is needed for 2 reasons: @@ -166,15 +173,17 @@ class TunableRunner(ABC): /, # tensors are position only inputs: List[torch.Tensor], *, # all others are keyword args only - tactic: int = -1, - do_preparation: bool = False) -> Any: + tactic: Any = -1, + do_preparation: bool = False, + **kwargs) -> Any: """Forward pass for tunable runners. Args: inputs: List of input tensors (position-only argument) - tactic: Integer ID specifying which implementation tactic to use. - -1 (default) represents the fallback tactic that must be implemented - to handle any input shapes when autotuning is disabled. + tactic: A arbitrary type that represents a specific kernel config. + For instance, it can be an integer number that specifies the unique ID of the implementation tactic to use. + -1 (default) represents the fallback tactic that must be implemented + to handle any input shapes when autotuning is disabled. do_preparation: When True, allows one-time setup operations to be performed before tactic evaluation begins. These operations are excluded from the performance measurements during autotuning. Notice that @@ -182,7 +191,7 @@ class TunableRunner(ABC): and can be accessed by the following forward calls. Returns: - Any: Output of the forward pass + Any: Output of the forward pass. """ raise NotImplementedError @@ -277,6 +286,7 @@ class AutoTuner: self.warmup = warmup self.stream_delay_micro_secs = stream_delay_micro_secs self.profiling_cache = {} + self.registered_tuning_configs = {} self.is_tuning_mode = False # Add statistics tracking @@ -296,7 +306,7 @@ class AutoTuner: runners: List[TunableRunner], input_shapes: Tuple[torch.Size], tuning_config: TuningConfig, - ) -> Tuple[bool, int, int, OptimizationProfile]: + ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]: """Search for cached profiling results matching the current configuration. Args: @@ -316,9 +326,14 @@ class AutoTuner: return False, 0, -1, None - def choose_one(self, custom_op: str, runners: List[TunableRunner], - tuning_config: TuningConfig, inputs: List[torch.Tensor], - **kwargs) -> Tuple[TunableRunner, int]: + def choose_one( + self, + custom_op: str, + runners: List[TunableRunner], + tuning_config: TuningConfig, + inputs: List[torch.Tensor], + **kwargs, + ) -> Tuple: """Choose the best runner and tactic combination through performance profiling. Args: @@ -329,9 +344,10 @@ class AutoTuner: **kwargs: Arbitrary keyword arguments, will be passed to get_valid_tactics and forward method of each runner Returns: - Tuple[TunableRunner, int]: A tuple containing: + Tuple: A tuple containing: - The selected runner implementation - The best tactic ID for that runner (-1 if using fallback) + - The best config for that runner (if configs is not empty) Note: The method profiles different implementations and tactics to find the @@ -342,26 +358,29 @@ class AutoTuner: """ input_shapes = tuple(self._get_input_sizes(inputs)) - # Early return if it's not tuning, use cache found one or fallback one if not self.is_tuning_mode: - is_cache_hit, runner_id, tactic, stored_profile = self.search_cache( + is_cache_hit, best_runner_id, best_tactic, stored_profile = self.search_cache( custom_op, runners, input_shapes, tuning_config) - runner = runners[runner_id] + best_runner = runners[best_runner_id] # TODO: check the stored runner and tactic can implement this shape here # Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf. - if not is_cache_hit and len(self.profiling_cache) > 0: - # Only log once for each custom op and only when cache is not empty + + # Record the cache miss config. + # Expect no cache miss in inference. Thus, any cache miss should be recorded. + if not is_cache_hit: logger.warning_once( f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}", key=(custom_op)) - return runner, tactic + + return (best_runner, best_tactic) assert len(runners) > 0, "At least one runner is required" assert all([isinstance(r, TunableRunner) for r in runners]), \ "All Given runners must be subclass of TunableRunner" profiles = self._optimization_profiles(tuning_config, inputs) + # Record the total configs to try self.stats.tuned_op_total_configs[custom_op] = len(profiles) @@ -369,63 +388,26 @@ class AutoTuner: for p in profiles: tensors = self._prepare_input_tensors(p, inputs) - is_cache_hit, runner_id, tactic, _ = self.search_cache( - custom_op, runners, p.get_opt_shapes(), tuning_config) + is_cache_hit, *_ = self.search_cache(custom_op, runners, + p.get_opt_shapes(), + tuning_config) if not is_cache_hit: - min_time = float('inf') # Initialize runner and tactic as None in case of no valid tactic or runners are found - runner_id, tactic = None, None - for r_id, r in enumerate(runners): - # TODO: use FakeTensor here. - valid_tactics = r.get_valid_tactics(tensors, p) - runner_arg_names = { - p.name - for p in inspect.signature( - r.forward).parameters.values() - } - if "do_preparation" in runner_arg_names and len( - valid_tactics) > 0: - r(tensors, tactic=-1, do_preparation=True, **kwargs) - for tac in valid_tactics: - try: - time_measured = self._profile_single_kernel( - r, tensors, tac, **kwargs) - except Exception as e: - shapes = self._get_input_sizes(tensors) - - logger.warning( - f"[Autotuner] Failed when profiling runner={r}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details." - ) - logger.debug(f"[Autotuner] Exception captured: {e}") - - # Record the failed profiling combinations - new_tuning_failure_occured = True - if custom_op not in self.stats.failed_profiling_count: - self.stats.failed_profiling_count[ - custom_op] = set() - self.stats.failed_profiling_count[custom_op].add( - AutoTuner._get_cache_key( - custom_op, r, p.get_opt_shapes(), - tuning_config)) - - # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics - # or some runtime error occurs during profiling. - time_measured = float('inf') - if time_measured < min_time: - min_time = time_measured - runner_id, tactic = r_id, tac - if runner_id is not None: + best_runner_id, best_tactic, has_tuning_failure_occured = self._profile_runners( + custom_op, runners, tensors, p, tuning_config, **kwargs) + if best_runner_id is not None: # At least one valid (runner, tactic) pair is found cache_key = AutoTuner._get_cache_key( - custom_op, runners[runner_id], p.get_opt_shapes(), + custom_op, runners[best_runner_id], p.get_opt_shapes(), tuning_config) # inspect call stack - self.profiling_cache[cache_key] = (runner_id, tactic, p) + self.profiling_cache[cache_key] = (best_runner_id, + best_tactic, p) self.stats.tuned_op_successful_configs[ custom_op] = self.stats.tuned_op_successful_configs.get( custom_op, 0) + 1 logger.debug( - f"[Autotuner] Profiling runner={runners[runner_id]}, tactic={tactic} for cache_key={cache_key}." + f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}." ) else: logger.warning( @@ -434,6 +416,7 @@ class AutoTuner: f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op " f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash." ) + new_tuning_failure_occured = new_tuning_failure_occured or has_tuning_failure_occured # If failed profiling tactics occurs, log the error. if new_tuning_failure_occured: @@ -450,7 +433,64 @@ class AutoTuner: _, runner_id, tactic, _ = self.search_cache(custom_op, runners, input_shapes, tuning_config) - return runners[runner_id], tactic + return (runners[runner_id], tactic) + + def _profile_runners( + self, + custom_op: str, + runners: List[TunableRunner], + input_tensors: List[torch.Tensor], + profile: OptimizationProfile, + tuning_config: TuningConfig, + **kwargs, + ) -> float: + min_time = float('inf') + has_tuning_failure_occured = False + best_runner_id, best_tactic = None, None + for runner_id, runner in enumerate(runners): + # TODO: use FakeTensor here. + runner_arg_names = { + p.name + for p in inspect.signature(runner.forward).parameters.values() + } + valid_tactics = runner.get_valid_tactics(input_tensors, profile) + if "do_preparation" in runner_arg_names and len(valid_tactics) > 0: + runner( + input_tensors, + tactic=-1, + do_preparation=True, + **kwargs, + ) + + for tac in valid_tactics: + try: + time_measured = self._profile_single_kernel( + runner, input_tensors, tac, **kwargs) + except Exception as e: + # Handle None tensors for optional inputs + shapes = self._get_input_sizes(input_tensors) + logger.warning( + f"[Autotuner] Failed when profiling runner={runner}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details." + ) + logger.debug(f"[Autotuner] Exception captured: {e}") + + # Record the failed profiling combinations + if custom_op not in self.stats.failed_profiling_count: + self.stats.failed_profiling_count[custom_op] = set() + self.stats.failed_profiling_count[custom_op].add( + AutoTuner._get_cache_key(custom_op, runner, + profile.get_opt_shapes(), + tuning_config)) + + # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics + # or some runtime error occurs during profiling. + time_measured = float('inf') + has_tuning_failure_occured = True + if time_measured < min_time: + min_time = time_measured + best_runner_id, best_tactic = runner_id, tac + + return best_runner_id, best_tactic, has_tuning_failure_occured def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]: @@ -462,15 +502,19 @@ class AutoTuner: return sizes - def _profile_single_kernel(self, runner: TunableRunner, - inputs: List[torch.Tensor], tactic: int, - **kwargs) -> float: + def _profile_single_kernel( + self, + runner: TunableRunner, + inputs: List[torch.Tensor], + tactic: Any, + **kwargs, + ) -> float: """Profile a single kernel implementation for performance measurement. Args: runner (TunableRunner): The runner implementation to profile inputs (List[torch.Tensor]): Input tensors for the kernel - tactic (int): Tactic ID to use for this profiling run + tactic (Any): Tactic to use for this profiling run Returns: Average execution time in milliseconds @@ -503,7 +547,7 @@ class AutoTuner: shapes = self._get_input_sizes(inputs) logger.debug( - f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time}ms." + f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms." ) return avg_time @@ -541,10 +585,23 @@ class AutoTuner: assert inspect.isfunction(spec.gen_tuning_buckets) or isinstance(spec.gen_tuning_buckets, (list, tuple)), \ "The given dynamic dimension must provide a opt value generation function or a list of opt values" if inspect.isfunction(spec.gen_tuning_buckets): - opt_shapes = spec.gen_tuning_buckets( - base_profile.shapes[spec.input_idx][spec.dim_idx].val) + if tuning_config.tune_max_num_tokens is None: + # Use the current input size as the opt value + opt_shapes = spec.gen_tuning_buckets( + base_profile.shapes[spec.input_idx][spec.dim_idx].val) + else: + # Use the tune_max_num_tokens as the opt value + opt_shapes = spec.gen_tuning_buckets( + tuning_config.tune_max_num_tokens) else: + # Default values is an empty tuple, means that user does not want to tune this dimension. opt_shapes = spec.gen_tuning_buckets + # Add the current input value as one of the opt values + opt_shapes = set(opt_shapes) + opt_shapes.add( + spec.map_to_tuning_buckets( + base_profile.shapes[spec.input_idx][spec.dim_idx].val)) + opt_shapes = sorted(list(opt_shapes)) opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'), ) opt_shapes_max = { v1: v2 @@ -570,6 +627,8 @@ class AutoTuner: for spec in tuning_config.constraint_specs: min_value = opt_value = max_value = spec.infer_shape( p.get_opt_shapes()) + if p.shapes[spec.input_idx] == [StaticDim(0)]: + continue p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim( min_value, opt_value, max_value) generated_profiles.append(p) @@ -578,8 +637,13 @@ class AutoTuner: @classmethod @lru_cache(maxsize=None) - def _find_nearest_profile(cls, shapes: Tuple[torch.Size], - tuning_config: TuningConfig) -> Tuple: + def _find_nearest_profile( + cls, + shapes: Tuple[torch.Size], + dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...], + constraint_specs: Tuple[ConstraintSpec, ...], + tune_max_num_tokens: int = None, + ) -> Tuple: """Find the nearest optimization profile for given inputs User can define their own nearest profile generation method to reduce the host overhead. @@ -594,13 +658,20 @@ class AutoTuner: """ base_profile = list(list(shape) for shape in shapes) - for spec in tuning_config.dynamic_tensor_specs: + for spec in dynamic_tensor_specs: base_profile[spec.input_idx][ spec.dim_idx] = spec.map_to_tuning_buckets( base_profile[spec.input_idx][spec.dim_idx]) + if tune_max_num_tokens is not None: + base_profile[spec.input_idx][spec.dim_idx] = min( + base_profile[spec.input_idx][spec.dim_idx], + tune_max_num_tokens) + # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile - for spec in tuning_config.constraint_specs: + for spec in constraint_specs: + if base_profile[spec.input_idx] == [0]: + continue base_profile[spec.input_idx][spec.dim_idx] = -1 return tuple(tuple(shape) for shape in base_profile) @@ -614,7 +685,10 @@ class AutoTuner: tuning_config: TuningConfig, ) -> Tuple: return (custom_op, runner.__class__.__name__, hash(runner), - cls._find_nearest_profile(input_shapes, tuning_config)) + cls._find_nearest_profile(input_shapes, + tuning_config.dynamic_tensor_specs, + tuning_config.constraint_specs, + tuning_config.tune_max_num_tokens)) def _create_tensor_like(self, origin_tensor: torch.Tensor, dims: List[Dim]) -> torch.Tensor: @@ -672,5 +746,6 @@ class AutoTuner: f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))" ) for key, value in self.profiling_cache.items(): - runner_id, tactic, _ = value - logger.debug(f"[Autotuner] {key}: ({runner_id}, {tactic})") + runner_id, tactic, profile = value + logger.debug( + f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic})") diff --git a/tensorrt_llm/_torch/compilation/backend.py b/tensorrt_llm/_torch/compilation/backend.py index f6e7ae6490..02e2ae8fe5 100644 --- a/tensorrt_llm/_torch/compilation/backend.py +++ b/tensorrt_llm/_torch/compilation/backend.py @@ -37,7 +37,7 @@ class Backend: enable_inductor=True, enable_userbuffers=False, enable_piecewise_cuda_graph: bool = False, - cuda_graph_batch_sizes: Optional[List[int]] = None, + capture_num_tokens: Optional[List[int]] = None, max_num_streams: int = 1, ) -> None: super().__init__() @@ -48,14 +48,12 @@ class Backend: self.custom_passes = Backend.get_custom_pass(enable_userbuffers) self.rank = tensorrt_llm.mpi_rank() self.enable_inductor = enable_inductor - self.cuda_graph_batch_sizes = (cuda_graph_batch_sizes - if cuda_graph_batch_sizes is not None - else []) + self.capture_num_tokens = capture_num_tokens or [] self.piecewise_cuda_graph = enable_piecewise_cuda_graph self.no_optimization = False # We only need to create aux streams. self.aux_streams = Backend.Streams( - [torch.cuda.Stream() for i in range(max_num_streams - 1)]) + [torch.cuda.Stream() for _ in range(max_num_streams - 1)]) self.events = Backend.Events() inductor_config.enable_auto_functionalized_v2 = False @@ -125,7 +123,7 @@ class Backend: example_inputs, self.enable_inductor, self.input_num_tokens, - self.cuda_graph_batch_sizes, + self.capture_num_tokens, self._graph_pool_handle, len(self.aux_streams) + 1, ) diff --git a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py index f7624e6b16..c83644eed2 100644 --- a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py +++ b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py @@ -14,8 +14,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag, make_weak_ref) from .multi_stream.auto_multi_stream import multi_stream_schedule -from .utils import (get_enable_piecewise_cuda_graph_capture_flag, - is_call_function) +from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function class PiecewiseInterpreter(Interpreter): @@ -25,7 +24,7 @@ class PiecewiseInterpreter(Interpreter): module: GraphModule, enable_inductor: bool, compile_time_num_tokens: Union[int | torch.SymInt], - cuda_graph_batch_sizes: list[int], + capture_num_tokens: list[int], exclude_modules_id: list[int], graph_pool_handle: tuple[int, int], garbage_collect_values: bool = True, @@ -37,7 +36,7 @@ class PiecewiseInterpreter(Interpreter): self.fake_mode = detect_fake_mode() self.compile_time_num_tokens = compile_time_num_tokens - self.cuda_graph_batch_sizes = cuda_graph_batch_sizes + self.capture_num_tokens = capture_num_tokens self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id] self.graph_pool_handle = graph_pool_handle self.enable_inductor = enable_inductor @@ -86,7 +85,7 @@ class PiecewiseInterpreter(Interpreter): target, self.compile_time_num_tokens, runtime_num_tokens_idx, - self.cuda_graph_batch_sizes, + self.capture_num_tokens, self.graph_pool_handle, compile_fx(submod, args) if self.enable_inductor else submod, self.enable_inductor, @@ -120,7 +119,7 @@ class PiecewiseRunner(object): name: str, compile_time_num_tokens: Union[int | torch.SymInt], runtime_num_tokens_idx: tuple[int], - cuda_graph_batch_sizes: List[int], + capture_num_tokens: List[int], graph_pool_handle, default_callable: Callable, enable_inductor: bool, @@ -139,9 +138,9 @@ class PiecewiseRunner(object): self.entries: dict[int, Entry] = {} - for bs in cuda_graph_batch_sizes: - self.entries[bs] = Entry( - bs, + for num_tokens in capture_num_tokens: + self.entries[num_tokens] = Entry( + num_tokens, enable_inductor=self.enable_inductor, callable=default_callable, ) @@ -167,7 +166,7 @@ class PiecewiseRunner(object): if entry.cuda_graph is None: - if not get_enable_piecewise_cuda_graph_capture_flag(): + if not get_capture_piecewise_cuda_graph_flag(): return entry.callable(*args) if entry.warmup_count < 3: @@ -228,7 +227,7 @@ def piecewise_optimizer( example_inputs: List[torch.Tensor], enable_inductor: bool, input_num_tokens: Union[int | torch.SymInt], - cuda_graph_batch_sizes: Sequence[int], + capture_num_tokens: Sequence[int], graph_pool_handle: tuple[int, int], max_num_streams: int = 1, ) -> tuple[GraphModule, int]: @@ -269,7 +268,7 @@ def piecewise_optimizer( gm, enable_inductor, input_num_tokens, - cuda_graph_batch_sizes, + capture_num_tokens, exclude_modules_id, graph_pool_handle, max_num_streams=max_num_streams, diff --git a/tensorrt_llm/_torch/compilation/utils.py b/tensorrt_llm/_torch/compilation/utils.py index fef3de2a06..0166c455d2 100644 --- a/tensorrt_llm/_torch/compilation/utils.py +++ b/tensorrt_llm/_torch/compilation/utils.py @@ -1,3 +1,4 @@ +import contextlib from typing import Callable, List, Union import torch @@ -33,16 +34,26 @@ def is_call_function(node: Node, target: Union[List[Callable], Callable]): _enable_piecewise_cuda_graph_capture = False -def set_enable_piecewise_cuda_graph_capture_flag(enable: bool): +def set_capture_piecewise_cuda_graph_flag(enable: bool): global _enable_piecewise_cuda_graph_capture _enable_piecewise_cuda_graph_capture = enable -def get_enable_piecewise_cuda_graph_capture_flag() -> bool: +def get_capture_piecewise_cuda_graph_flag() -> bool: global _enable_piecewise_cuda_graph_capture return _enable_piecewise_cuda_graph_capture +@contextlib.contextmanager +def capture_piecewise_cuda_graph(enable: bool): + prev_enable = get_capture_piecewise_cuda_graph_flag() + set_capture_piecewise_cuda_graph_flag(enable) + try: + yield + finally: + set_capture_piecewise_cuda_graph_flag(prev_enable) + + def inplace_info(): inplace_map = { torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: { diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 5e001d9a48..ba71e4fbfe 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -57,12 +57,12 @@ def _register_fake(): #MNNVL Allreduce @torch.library.register_fake("trtllm::mnnvl_twoshot_allreduce") - def _(input, buffer, buffer_flags, wait_for_results): + def _(input, buffer, buffer_flags, buffer_size, wait_for_results): output = input.new_empty(input.shape) return output @torch.library.register_fake("trtllm::mnnvl_twoshot_rmsnorm") - def _(comm_buf, gamma, eps, residual, buffer_flags): + def _(comm_buf, gamma, eps, residual, buffer_flags, buffer_size): output = residual.new_empty(residual.shape) residual_out = residual.new_empty(residual.shape) return [output, residual_out] @@ -458,7 +458,7 @@ def _register_fake(): gemm2_output: torch.Tensor, fc2_expert_biases: torch.Tensor, unpermuted_final_scales: torch.Tensor, - expanded_source_row_to_expanded_dest_row: torch.Tensor, + unpermuted_row_to_permuted_row: torch.Tensor, expert_for_source_row: torch.Tensor, expert_first_token_offset_tensor: torch.Tensor, num_rows: torch.SymInt, @@ -501,7 +501,7 @@ def _register_fake(): shape[0] = sizes[local_rank] return input.new_empty(shape) - @torch.library.register_fake("trtllm::nvfp4_block_scale_interleave") + @torch.library.register_fake("trtllm::block_scale_interleave") def _(sf: torch.Tensor): rows = sf.shape[-2] cols = sf.shape[-1] @@ -511,7 +511,7 @@ def _register_fake(): return sf.new_empty((num_experts * expert_out_size, ), dtype=torch.uint8) - @torch.library.register_fake("trtllm::nvfp4_block_scale_interleave_reverse") + @torch.library.register_fake("trtllm::block_scale_interleave_reverse") def _(sf: torch.Tensor): return torch.empty_like(sf, dtype=torch.uint8) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index c98f9782bb..bd946343b0 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -8,6 +8,7 @@ from tensorrt_llm._utils import get_sm_version from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, OptimizationProfile, TunableRunner, TuningConfig) +from ..modules.multi_stream_utils import do_multi_stream from ..utils import (fp4_scale_infer_shape, get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2) @@ -22,9 +23,12 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None: class MoERunner(TunableRunner): # avoid overhead of creating a new runner in forward pass runner_dict = dict() - tuning_config = TuningConfig(dynamic_tensor_specs=( - DynamicTensorSpec(0, 0, get_last_power_of_2_num_tokens_buckets(8192), - lambda x: min(last_positive_power_of_2(x), 8192)), )) + tuning_config = TuningConfig( + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets(8192), + lambda x: min(last_positive_power_of_2(x), 8192)), ), + tune_max_num_tokens=8192, + ) def __init__( self, @@ -39,9 +43,11 @@ class MoERunner(TunableRunner): cluster_size: int, cluster_rank: int, use_deepseek_fp8_block_scale: bool, - use_w4a8_group_scaling: bool, + use_w4_group_scaling: bool, + use_int8_woq_per_channel: bool, use_mxfp8_act_scaling: bool, min_latency_mode: bool, + use_fused_finalize: bool, ): self.x_dtype = x_dtype self.weight_dtype = weight_dtype @@ -56,19 +62,23 @@ class MoERunner(TunableRunner): # The best tactic is estimated as if alltoall is disabled self.enable_alltoall = False self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale - self.use_w4a8_group_scaling = use_w4a8_group_scaling + self.use_w4_group_scaling = use_w4_group_scaling + self.use_int8_woq_per_channel = use_int8_woq_per_channel self.use_mxfp8_act_scaling = use_mxfp8_act_scaling self.min_latency_mode = min_latency_mode + self.use_fused_finalize = use_fused_finalize + instance_key = (x_dtype, weight_dtype, output_dtype, - use_deepseek_fp8_block_scale, use_w4a8_group_scaling, - use_mxfp8_act_scaling) + use_deepseek_fp8_block_scale, use_w4_group_scaling, + use_int8_woq_per_channel, use_mxfp8_act_scaling) if instance_key not in MoERunner.runner_dict: MoERunner.runner_dict[ instance_key] = torch.classes.trtllm.FusedMoeRunner( x_dtype, weight_dtype, output_dtype, - use_deepseek_fp8_block_scale, use_w4a8_group_scaling, - use_mxfp8_act_scaling) + use_deepseek_fp8_block_scale, use_w4_group_scaling, + use_int8_woq_per_channel, use_mxfp8_act_scaling, + use_fused_finalize) self.fused_moe_runner = MoERunner.runner_dict[instance_key] def get_valid_tactics( @@ -106,15 +116,6 @@ class MoERunner(TunableRunner): do_preparation, ) - @classmethod - @lru_cache(maxsize=None) - def refine_tuning_config(cls, tune_max_num_tokens: int): - cls.tuning_config = TuningConfig( - dynamic_tensor_specs=(DynamicTensorSpec( - 0, 0, get_last_power_of_2_num_tokens_buckets( - tune_max_num_tokens), lambda x: min( - last_positive_power_of_2(x), tune_max_num_tokens)), )) - @torch.library.custom_op("trtllm::fused_moe", mutates_args=()) def fused_moe( @@ -128,6 +129,10 @@ def fused_moe( output_dtype: torch.dtype, quant_scales: List[torch.Tensor], input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, @@ -136,16 +141,17 @@ def fused_moe( cluster_rank: int = 0, enable_alltoall: bool = False, use_deepseek_fp8_block_scale: bool = False, - use_w4a8_group_scaling: bool = False, + use_w4_group_scaling: bool = False, + use_int8_woq_per_channel: bool = False, use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, + use_fused_finalize: bool = True, tune_max_num_tokens: int = 8192, tuner_num_tokens: Optional[int] = None, tuner_top_k: Optional[int] = None, ) -> List[torch.Tensor]: tuner = AutoTuner.get() - MoERunner.refine_tuning_config(tune_max_num_tokens) # Only the non-alltoall case is considered for profiling in the warmup phase. # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall. @@ -172,11 +178,15 @@ def fused_moe( cluster_size=cluster_size, cluster_rank=cluster_rank, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - use_w4a8_group_scaling=use_w4a8_group_scaling, + use_w4_group_scaling=use_w4_group_scaling, + use_int8_woq_per_channel=use_int8_woq_per_channel, use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, ) + MoERunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens + _, gemm_tactic_1 = tuner.choose_one( "trtllm::fused_moe::gemm1", [moe_runner], @@ -210,6 +220,10 @@ def fused_moe( fc2_expert_biases, quant_scales, input_sf, + swizzled_input_sf, + swiglu_alpha, + swiglu_beta, + swiglu_limit, tp_size, tp_rank, ep_size, @@ -236,6 +250,10 @@ def _( output_dtype: torch.dtype, quant_scales: List[torch.Tensor], input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, @@ -244,13 +262,20 @@ def _( cluster_rank: int = 0, enable_alltoall: bool = False, use_deepseek_fp8_block_scale: bool = False, - use_w4a8_group_scaling: bool = False, + use_w4_group_scaling: bool = False, + use_int8_woq_per_channel: bool = False, use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, + use_fused_finalize: bool = True, tune_max_num_tokens: int = 8192, ): seq_len = input.shape[0] - hidden_size = fc2_expert_weights.shape[1] + if use_int8_woq_per_channel: + # Note: The weight shape for INT8 weight only quantization is different, i.e., + # fc2_expert_weights: [num_experts, inter_size, hidden_size] + hidden_size = fc2_expert_weights.shape[2] + else: + hidden_size = fc2_expert_weights.shape[1] if min_latency_mode: num_experts_on_rank = fc2_expert_weights.shape[0] @@ -901,6 +926,8 @@ def get_stream(stream_id: int): @torch.library.custom_op("trtllm::set_stream", mutates_args=()) def set_stream(stream_id: int) -> None: + if not do_multi_stream(): + return stream = get_stream(stream_id) assert stream is not None torch.cuda.set_stream(stream) @@ -908,18 +935,24 @@ def set_stream(stream_id: int) -> None: @torch.library.custom_op("trtllm::record_event", mutates_args=()) def record_event(event_idx: int) -> None: + if not do_multi_stream(): + return event = get_event(event_idx) event.record() @torch.library.custom_op("trtllm::wait_event", mutates_args=()) def wait_event(event_idx: int) -> None: + if not do_multi_stream(): + return event = get_event(event_idx) event.wait() @torch.library.custom_op("trtllm::record_stream", mutates_args=()) def record_stream(tensor: torch.Tensor, stream_id: int) -> None: + if not do_multi_stream(): + return stream = get_stream(stream_id) assert stream is not None tensor.record_stream(stream) diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index 622fa12c51..2bb780f6ef 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -563,3 +563,672 @@ def _( return hidden_states.new_empty((num_tokens, hidden_size), dtype=torch.bfloat16) + + +@dataclass(frozen=True) +class MxE4m3MxE2m1BlockScaleMoEInputs: + routing_logits: torch.Tensor + routing_bias: Optional[torch.Tensor] + hidden_states: torch.Tensor + hidden_states_scale: torch.Tensor + gemm1_weights: torch.Tensor + gemm1_weights_scale: torch.Tensor + gemm1_bias: Optional[torch.Tensor] + gemm1_alpha: Optional[torch.Tensor] + gemm1_beta: Optional[torch.Tensor] + gemm1_clamp_limit: Optional[torch.Tensor] + gemm2_weights: torch.Tensor + gemm2_weights_scale: torch.Tensor + gemm2_bias: Optional[torch.Tensor] + + +class MxE4m3MxE2m1BlockScaleMoERunner(TunableRunner): + + runner_dict = dict() + tuning_config = None + + def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], + topk_group: Optional[int], intermediate_size: int, + hidden_size_output: int, local_expert_offset: int, + local_num_experts: int, routed_scaling_factor: Optional[float], + tile_tokens_dim: int, routing_method_type: int, act_type: int): + + self.num_experts = num_experts + self.top_k = top_k + self.n_group = n_group + self.topk_group = topk_group + self.intermediate_size = intermediate_size + self.hidden_size_output = hidden_size_output + self.local_expert_offset = local_expert_offset + self.local_num_experts = local_num_experts + self.routed_scaling_factor = routed_scaling_factor + self.tile_tokens_dim = tile_tokens_dim + self.routing_method_type = routing_method_type + self.act_type = act_type + + MxE4m3MxE2m1BlockScaleMoERunner.tuning_config = MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config( + ) + + instance_key = ( + self.top_k, + self.intermediate_size, + self.hidden_size_output, + self.local_num_experts, + self.tile_tokens_dim, + self.act_type, + ) + + if instance_key not in MxE4m3MxE2m1BlockScaleMoERunner.runner_dict: + MxE4m3MxE2m1BlockScaleMoERunner.runner_dict[ + instance_key] = torch.classes.trtllm.MxE4m3MxE2m1BlockScaleMoERunner( + self.tile_tokens_dim, self.act_type, True) + + self.kernel_runner = MxE4m3MxE2m1BlockScaleMoERunner.runner_dict[ + instance_key] + + # The hash is used by the autotuner to get the cache key, so we hash on members + # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing + # type does not matter + def __hash__(self): + return hash(( + self.top_k, + self.intermediate_size, + self.hidden_size_output, + self.local_num_experts, + self.tile_tokens_dim, + self.act_type, + )) + + # __eq__ and __hash__ must agree + def __eq__(self, other): + if not isinstance(other, MxE4m3MxE2m1BlockScaleMoERunner): + return False + + return (self.top_k == other.top_k + and self.intermediate_size == other.intermediate_size + and self.hidden_size_output == other.hidden_size_output + and self.local_num_experts == other.local_num_experts + and self.tile_tokens_dim == other.tile_tokens_dim + and self.act_type == other.act_type) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + ) -> torch.Tensor: + + args = MxE4m3MxE2m1BlockScaleMoEInputs(*inputs) + + return self.kernel_runner.run_moe( + args.routing_logits, args.routing_bias, args.hidden_states, + args.hidden_states_scale, args.gemm1_weights, + args.gemm1_weights_scale, args.gemm1_bias, args.gemm1_alpha, + args.gemm1_beta, args.gemm1_clamp_limit, args.gemm2_weights, + args.gemm2_weights_scale, args.gemm2_bias, None, None, None, + self.num_experts, self.top_k, self.n_group, self.topk_group, + self.intermediate_size, self.hidden_size_output, + self.local_expert_offset, self.local_num_experts, + self.routed_scaling_factor, self.routing_method_type, tactic) + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + + args = MxE4m3MxE2m1BlockScaleMoEInputs(*inputs) + + num_tokens = args.hidden_states.shape[0] + hidden_size = args.hidden_states.shape[1] + + tactics = self.kernel_runner.get_valid_configs(self.top_k, hidden_size, + self.intermediate_size, + self.local_num_experts, + num_tokens) + + return tactics + + @classmethod + def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]: + HIDDEN_STATES_IDX = 2 + TUNED_DIM = 0 + + m_values = get_last_power_of_2_num_tokens_buckets(4096) + round_rule = lambda x: min(last_positive_power_of_2(x), 4096) + + specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values, + round_rule), ) + + return specs + + @classmethod + def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]: + + def _constrain_hidden_states_scale(shapes: Tuple[torch.Size]) -> int: + # hidden_states dim 0 and dim 1 + num_tokens = shapes[2][0] + hidden_size = shapes[2][1] + + SF_BLOCK_SIZE = 32 + + # Linear fp4 sf layout is just rows * columns + size = num_tokens * (hidden_size // SF_BLOCK_SIZE) + + return size + + def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int: + # hidden_states dim 0 and dim 1 + num_tokens = shapes[2][0] + + return num_tokens + + HIDDEN_STATE_SCALE_IDX = 3 + CONSTRAINED_HS_DIM = 0 + + constraint_hidden_states_scale = ConstraintSpec( + HIDDEN_STATE_SCALE_IDX, CONSTRAINED_HS_DIM, + _constrain_hidden_states_scale) + + ROUTER_LOGITS_IDX = 0 + CONSTRAINED_RL_DIM = 0 + + constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX, + CONSTRAINED_RL_DIM, + _constrain_routing_logits) + + constraint_specs_tuple = (constraint_hidden_states_scale, + constraint_routing_logits) + + return constraint_specs_tuple + + @classmethod + @lru_cache(maxsize=None) + def get_tuning_config(cls) -> TuningConfig: + + dynamic_tensor_specs = cls.get_dynamic_tensor_specs() + constraint_specs = cls.get_constraint_specs() + + tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs) + + return tuning_config + + +@torch.library.custom_op("trtllm::mxe4m3_mxe2m1_block_scale_moe_runner", + mutates_args=()) +def mxe4m3_mxe2m1_block_scale_moe_runner( + routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor, + gemm1_weights: torch.Tensor, gemm1_weights_scale: torch.Tensor, + gemm1_bias: Optional[torch.Tensor], gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, gemm2_bias: Optional[torch.Tensor], + num_experts: int, top_k: int, n_group: Optional[int], + topk_group: Optional[int], intermediate_size: int, + hidden_size_output: int, local_expert_offset: int, + local_num_experts: int, routed_scaling_factor: Optional[float], + tile_tokens_dim: int, routing_method_type: int, + act_type: int) -> torch.Tensor: + + tuner = AutoTuner.get() + + kernel_runner = MxE4m3MxE2m1BlockScaleMoERunner( + num_experts, top_k, n_group, topk_group, intermediate_size, + hidden_size_output, local_expert_offset, local_num_experts, + routed_scaling_factor, tile_tokens_dim, routing_method_type, act_type) + + input_tensors = [ + routing_logits, + routing_bias, + hidden_states, + hidden_states_scale, + gemm1_weights, + gemm1_weights_scale, + gemm1_bias, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + gemm2_weights, + gemm2_weights_scale, + gemm2_bias, + ] + + _, best_tactic = tuner.choose_one( + "trtllm::mxe4m3_mxe2m1_block_scale_moe_runner", + [kernel_runner], + kernel_runner.tuning_config, + input_tensors, + ) + + return kernel_runner(input_tensors, tactic=best_tactic) + + +@dataclass(frozen=True) +class E4m3MxE2m1BlockScaleMoEInputs: + routing_logits: torch.Tensor + routing_bias: Optional[torch.Tensor] + hidden_states: torch.Tensor + gemm1_weights: torch.Tensor + gemm1_weights_scale: torch.Tensor + gemm1_bias: Optional[torch.Tensor] + gemm1_alpha: Optional[torch.Tensor] + gemm1_beta: Optional[torch.Tensor] + gemm1_clamp_limit: Optional[torch.Tensor] + gemm2_weights: torch.Tensor + gemm2_weights_scale: torch.Tensor + gemm2_bias: Optional[torch.Tensor] + output1_scale_scalar: torch.Tensor + output1_scale_gate_scalar: torch.Tensor + output2_scale_scalar: torch.Tensor + + +class E4m3MxE2m1BlockScaleMoERunner(TunableRunner): + + runner_dict = dict() + tuning_config = None + + def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], + topk_group: Optional[int], intermediate_size: int, + local_expert_offset: int, local_num_experts: int, + routed_scaling_factor: Optional[float], tile_tokens_dim: int, + routing_method_type: int, act_type: int): + + self.num_experts = num_experts + self.top_k = top_k + self.n_group = n_group + self.topk_group = topk_group + self.intermediate_size = intermediate_size + self.local_expert_offset = local_expert_offset + self.local_num_experts = local_num_experts + self.routed_scaling_factor = routed_scaling_factor + self.tile_tokens_dim = tile_tokens_dim + self.routing_method_type = routing_method_type + self.act_type = act_type + + E4m3MxE2m1BlockScaleMoERunner.tuning_config = E4m3MxE2m1BlockScaleMoERunner.get_tuning_config( + ) + + instance_key = ( + self.top_k, + self.intermediate_size, + self.local_num_experts, + self.tile_tokens_dim, + self.act_type, + ) + + if instance_key not in E4m3MxE2m1BlockScaleMoERunner.runner_dict: + E4m3MxE2m1BlockScaleMoERunner.runner_dict[ + instance_key] = torch.classes.trtllm.MxE4m3MxE2m1BlockScaleMoERunner( + self.tile_tokens_dim, self.act_type, False) + + self.kernel_runner = E4m3MxE2m1BlockScaleMoERunner.runner_dict[ + instance_key] + + # The hash is used by the autotuner to get the cache key, so we hash on members + # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing + # type does not matter + def __hash__(self): + return hash(( + self.top_k, + self.intermediate_size, + self.local_num_experts, + self.tile_tokens_dim, + self.act_type, + )) + + # __eq__ and __hash__ must agree + def __eq__(self, other): + if not isinstance(other, E4m3MxE2m1BlockScaleMoERunner): + return False + + return (self.top_k == other.top_k + and self.intermediate_size == other.intermediate_size + and self.local_num_experts == other.local_num_experts + and self.tile_tokens_dim == other.tile_tokens_dim + and self.act_type == other.act_type) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + ) -> torch.Tensor: + + args = E4m3MxE2m1BlockScaleMoEInputs(*inputs) + + return self.kernel_runner.run_moe( + args.routing_logits, args.routing_bias, args.hidden_states, None, + args.gemm1_weights, args.gemm1_weights_scale, args.gemm1_bias, + args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, + args.gemm2_weights, args.gemm2_weights_scale, args.gemm2_bias, + args.output1_scale_scalar, args.output1_scale_gate_scalar, + args.output2_scale_scalar, self.num_experts, self.top_k, + self.n_group, self.topk_group, self.intermediate_size, None, + self.local_expert_offset, self.local_num_experts, + self.routed_scaling_factor, self.routing_method_type, tactic) + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + + args = E4m3MxE2m1BlockScaleMoEInputs(*inputs) + + num_tokens = args.hidden_states.shape[0] + hidden_size = args.hidden_states.shape[1] + + tactics = self.kernel_runner.get_valid_configs(self.top_k, hidden_size, + self.intermediate_size, + self.local_num_experts, + num_tokens) + + return tactics + + @classmethod + def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]: + HIDDEN_STATES_IDX = 2 + TUNED_DIM = 0 + + m_values = get_last_power_of_2_num_tokens_buckets(4096) + round_rule = lambda x: min(last_positive_power_of_2(x), 4096) + + specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values, + round_rule), ) + + return specs + + @classmethod + def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]: + + def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int: + # hidden_states dim 0 and dim 1 + num_tokens = shapes[2][0] + + return num_tokens + + ROUTER_LOGITS_IDX = 0 + CONSTRAINED_RL_DIM = 0 + + constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX, + CONSTRAINED_RL_DIM, + _constrain_routing_logits) + + constraint_specs_tuple = (constraint_routing_logits, ) + + return constraint_specs_tuple + + @classmethod + @lru_cache(maxsize=None) + def get_tuning_config(cls) -> TuningConfig: + + dynamic_tensor_specs = cls.get_dynamic_tensor_specs() + constraint_specs = cls.get_constraint_specs() + + tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs) + + return tuning_config + + +@torch.library.custom_op("trtllm::e4m3_mxe2m1_block_scale_moe_runner", + mutates_args=()) +def e4m3_mxe2m1_block_scale_moe_runner( + routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, gemm1_bias: Optional[torch.Tensor], + gemm1_alpha: Optional[torch.Tensor], gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, gemm2_bias: Optional[torch.Tensor], + output1_scale_scalar: torch.Tensor, + output1_scale_gate_scalar: torch.Tensor, + output2_scale_scalar: torch.Tensor, num_experts: int, top_k: int, + n_group: Optional[int], topk_group: Optional[int], + intermediate_size: int, local_expert_offset: int, + local_num_experts: int, routed_scaling_factor: Optional[float], + tile_tokens_dim: int, routing_method_type: int, + act_type: int) -> torch.Tensor: + + tuner = AutoTuner.get() + + kernel_runner = E4m3MxE2m1BlockScaleMoERunner( + num_experts, top_k, n_group, topk_group, intermediate_size, + local_expert_offset, local_num_experts, routed_scaling_factor, + tile_tokens_dim, routing_method_type, act_type) + + input_tensors = [ + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm1_weights_scale, + gemm1_bias, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + gemm2_weights, + gemm2_weights_scale, + gemm2_bias, + output1_scale_scalar, + output1_scale_gate_scalar, + output2_scale_scalar, + ] + + _, best_tactic = tuner.choose_one( + "trtllm::e4m3_mxe2m1_block_scale_moe_runner", + [kernel_runner], + kernel_runner.tuning_config, + input_tensors, + ) + + return kernel_runner(input_tensors, tactic=best_tactic) + + +@dataclass(frozen=True) +class Bf16MxE2m1BlockScaleMoEInputs: + routing_logits: torch.Tensor + routing_bias: Optional[torch.Tensor] + hidden_states: torch.Tensor + gemm1_weights: torch.Tensor + gemm1_weights_scale: torch.Tensor + gemm1_bias: Optional[torch.Tensor] + gemm1_alpha: Optional[torch.Tensor] + gemm1_beta: Optional[torch.Tensor] + gemm1_clamp_limit: Optional[torch.Tensor] + gemm2_weights: torch.Tensor + gemm2_weights_scale: torch.Tensor + gemm2_bias: Optional[torch.Tensor] + + +class Bf16MxE2m1BlockScaleMoERunner(TunableRunner): + + runner_dict = dict() + tuning_config = None + + def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], + topk_group: Optional[int], intermediate_size: int, + local_expert_offset: int, local_num_experts: int, + routed_scaling_factor: Optional[float], tile_tokens_dim: int, + routing_method_type: int, act_type: int): + + self.num_experts = num_experts + self.top_k = top_k + self.n_group = n_group + self.topk_group = topk_group + self.intermediate_size = intermediate_size + self.local_expert_offset = local_expert_offset + self.local_num_experts = local_num_experts + self.routed_scaling_factor = routed_scaling_factor + self.tile_tokens_dim = tile_tokens_dim + self.routing_method_type = routing_method_type + self.act_type = act_type + + Bf16MxE2m1BlockScaleMoERunner.tuning_config = Bf16MxE2m1BlockScaleMoERunner.get_tuning_config( + ) + + instance_key = ( + self.top_k, + self.intermediate_size, + self.local_num_experts, + self.tile_tokens_dim, + self.act_type, + ) + + if instance_key not in Bf16MxE2m1BlockScaleMoERunner.runner_dict: + Bf16MxE2m1BlockScaleMoERunner.runner_dict[ + instance_key] = torch.classes.trtllm.Bf16MxE2m1BlockScaleMoERunner( + self.tile_tokens_dim, self.act_type) + + self.kernel_runner = Bf16MxE2m1BlockScaleMoERunner.runner_dict[ + instance_key] + + # The hash is used by the autotuner to get the cache key, so we hash on members + # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing + # type does not matter + def __hash__(self): + return hash(( + self.top_k, + self.intermediate_size, + self.local_num_experts, + self.tile_tokens_dim, + self.act_type, + )) + + # __eq__ and __hash__ must agree + def __eq__(self, other): + if not isinstance(other, Bf16MxE2m1BlockScaleMoERunner): + return False + + return (self.top_k == other.top_k + and self.intermediate_size == other.intermediate_size + and self.local_num_experts == other.local_num_experts + and self.tile_tokens_dim == other.tile_tokens_dim + and self.act_type == other.act_type) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + ) -> torch.Tensor: + + args = Bf16MxE2m1BlockScaleMoEInputs(*inputs) + + return self.kernel_runner.run_moe( + args.routing_logits, args.routing_bias, args.hidden_states, + args.gemm1_weights, args.gemm1_weights_scale, args.gemm1_bias, + args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, + args.gemm2_weights, args.gemm2_weights_scale, args.gemm2_bias, + self.num_experts, self.top_k, self.n_group, self.topk_group, + self.intermediate_size, self.local_expert_offset, + self.local_num_experts, self.routed_scaling_factor, + self.routing_method_type, tactic) + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + + args = Bf16MxE2m1BlockScaleMoEInputs(*inputs) + + num_tokens = args.hidden_states.shape[0] + hidden_size = args.hidden_states.shape[1] + + tactics = self.kernel_runner.get_valid_configs(self.top_k, hidden_size, + self.intermediate_size, + self.local_num_experts, + num_tokens) + + return tactics + + @classmethod + def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]: + HIDDEN_STATES_IDX = 2 + TUNED_DIM = 0 + + m_values = get_last_power_of_2_num_tokens_buckets(4096) + round_rule = lambda x: min(last_positive_power_of_2(x), 4096) + + specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values, + round_rule), ) + + return specs + + @classmethod + def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]: + + def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int: + # hidden_states dim 0 and dim 1 + num_tokens = shapes[2][0] + + return num_tokens + + ROUTER_LOGITS_IDX = 0 + CONSTRAINED_DIM = 0 + + constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX, + CONSTRAINED_DIM, + _constrain_routing_logits) + + constraint_specs_tuple = (constraint_routing_logits, ) + + return constraint_specs_tuple + + @classmethod + @lru_cache(maxsize=None) + def get_tuning_config(cls) -> TuningConfig: + + dynamic_tensor_specs = cls.get_dynamic_tensor_specs() + constraint_specs = cls.get_constraint_specs() + + tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs) + + return tuning_config + + +@torch.library.custom_op("trtllm::bf16_mxe2m1_block_scale_moe_runner", + mutates_args=()) +def bf16_mxe2m1_block_scale_moe_runner( + routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, gemm1_bias: Optional[torch.Tensor], + gemm1_alpha: Optional[torch.Tensor], gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, gemm2_bias: Optional[torch.Tensor], + num_experts: int, top_k: int, n_group: Optional[int], + topk_group: Optional[int], intermediate_size: int, + local_expert_offset: int, local_num_experts: int, + routed_scaling_factor: Optional[float], tile_tokens_dim: int, + routing_method_type: int, act_type: int) -> torch.Tensor: + + tuner = AutoTuner.get() + + kernel_runner = Bf16MxE2m1BlockScaleMoERunner( + num_experts, top_k, n_group, topk_group, intermediate_size, + local_expert_offset, local_num_experts, routed_scaling_factor, + tile_tokens_dim, routing_method_type, act_type) + + input_tensors = [ + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm1_weights_scale, + gemm1_bias, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + gemm2_weights, + gemm2_weights_scale, + gemm2_bias, + ] + + _, best_tactic = tuner.choose_one( + "trtllm::bf16_mxe2m1_block_scale_moe_runner", + [kernel_runner], + kernel_runner.tuning_config, + input_tensors, + ) + + return kernel_runner(input_tensors, tactic=best_tactic) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index ba713a7d56..74ac9590a3 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -1,5 +1,7 @@ +import logging import math import os +import platform import threading from typing import List, Optional, Tuple, Union @@ -10,10 +12,12 @@ from tensorrt_llm._utils import mpi_barrier from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, AllReduceStrategy, MoEAllReduceParams) +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.plugin.plugin import CustomAllReduceHelper _thread_local = threading.local() +logger = logging.getLogger(__name__) def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor: @@ -63,11 +67,10 @@ def get_allreduce_mnnvl_workspace( if mapping not in allreduce_mnnvl_workspaces: # buffer shape: [3, 2, buffer_tokens, hidden_dim] stride = 3 * 2 * dtype.itemsize - # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 - # max_num_elements must be a multiple of 286720 - lcm_hidden_dim = 286720 + # Max hidden_size_to_support + max_hidden_dim = 16384 buffer_size_in_bytes = math.ceil( - 12_000_000 / (lcm_hidden_dim * stride)) * (lcm_hidden_dim * stride) + 12_000_000 / (max_hidden_dim * stride)) * (max_hidden_dim * stride) max_num_elements = buffer_size_in_bytes // stride mcast_buffer = McastGPUBuffer( @@ -75,7 +78,7 @@ def get_allreduce_mnnvl_workspace( mapping.tp_size, mapping.tp_rank, torch.device("cuda", mapping.local_rank), - mapping.is_multi_node() or force_mn, + True, # mnNvlink ) buffer = mcast_buffer.get_uc_buffer(mapping.tp_rank, @@ -88,8 +91,8 @@ def get_allreduce_mnnvl_workspace( # This is a buffer to maintain the state of this allreduce Op # Should have the same lifetime with self._buffer - # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_to_clear,atomic access counter] - buffer_flags = torch.tensor([0, 2, max_num_elements, 0, 0], + # [Buffer_ptr, Clear_ptr, num_tokens_to_clear,atomic access counter] + buffer_flags = torch.tensor([0, 2, 0, 0], dtype=torch.uint32, device=torch.device("cuda", mapping.local_rank)) @@ -291,6 +294,8 @@ class MNNVLAllReduce(nn.Module): for certain operations when using NVLink for multi-node communication. """ + SUPPORTED_FUSION_HIDDEN_DIMS = [2048, 2880, 4096, 5120, 7168, 8192] + def __init__(self, mapping: Mapping, dtype: torch.dtype): super().__init__() self.mapping = mapping @@ -307,6 +312,16 @@ class MNNVLAllReduce(nn.Module): def get_supported_dtypes(): return (torch.float16, torch.bfloat16, torch.float32) + @staticmethod + def is_mnnvl(mapping: Mapping, dtype: torch.dtype) -> bool: + from tensorrt_llm._mnnvl_utils import MnnvlMemory + + arch = platform.machine().lower() + is_on_aarch64 = "aarch64" in arch + return (dtype in MNNVLAllReduce.get_supported_dtypes() + and not mapping.has_cp() and mapping.is_multi_node() + and MnnvlMemory.supports_mnnvl() and is_on_aarch64) + def forward( self, input: torch.Tensor, @@ -321,40 +336,62 @@ class MNNVLAllReduce(nn.Module): Returns: Union[torch.Tensor, Tuple[torch.Tensor, ...]]: Reduced tensor(s) """ - if input.numel() > self.max_num_elements_mnnvl: - return None fusion_op = all_reduce_params.fusion_op - shape = input.shape + input = input.view(-1, shape[-1]) + (num_tokens, hidden_dim) = input.shape - if self.buffer_mnnvl.shape[-1] % shape[-1] != 0: + # Slice the buffer according to the hidden size, need to pass this numel as the new buffer size + max_num_tokens = self.max_num_elements_mnnvl // hidden_dim + num_elements_in_use = max_num_tokens * hidden_dim + if num_tokens > max_num_tokens: + logger.debug( + f"MNNVL AllReduce can't be enabled due to {num_tokens=} larger than {max_num_tokens=}." + ) + return None + + # This should not happen but leave this check for future code changes + if num_elements_in_use > self.max_num_elements_mnnvl: + logger.debug( + f"MNNVL AllReduce can't be enabled due to {num_elements_in_use=} larger than {self.max_num_elements_mnnvl=}." + ) return None - input = input.view(-1, shape[-1]) output = torch.empty_like(input) - buffer_mnnvl = self.buffer_mnnvl.view(3, 2, -1, shape[-1]) + buffer_mnnvl = self.buffer_mnnvl.view(-1)[:(3 * 2 * + num_elements_in_use)].view( + 3, 2, -1, hidden_dim) if fusion_op == AllReduceFusionOp.NONE: output = torch.ops.trtllm.mnnvl_twoshot_allreduce( input, buffer_mnnvl, self.buffer_flags_mnnvl, + num_elements_in_use, True, ) return output.view(shape) - elif fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM: + # Fallback to use other allreduce if hidden_size is not supported + elif (fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM + and hidden_dim in MNNVLAllReduce.SUPPORTED_FUSION_HIDDEN_DIMS): torch.ops.trtllm.mnnvl_twoshot_allreduce( input, buffer_mnnvl, self.buffer_flags_mnnvl, + num_elements_in_use, False, ) residual_in = all_reduce_params.residual output, residual_out = torch.ops.trtllm.mnnvl_twoshot_rmsnorm( - buffer_mnnvl, all_reduce_params.norm_weight, - all_reduce_params.eps, residual_in, self.buffer_flags_mnnvl) + buffer_mnnvl, + all_reduce_params.norm_weight, + all_reduce_params.eps, + residual_in, + self.buffer_flags_mnnvl, + num_elements_in_use, + ) return output.view(shape), residual_out.view(shape) return None @@ -418,11 +455,24 @@ class AllReduce(nn.Module): self.workspace = get_allreduce_workspace(self.mapping) # Initialize MNNVL AllReduce if needed - if self.strategy == AllReduceStrategy.MNNVL and ( - dtype and dtype in MNNVLAllReduce.get_supported_dtypes() - ) and (not self.mapping.has_cp()): - self.mnnvl_allreduce = MNNVLAllReduce(self.mapping, - dtype) if dtype else None + if self.strategy == AllReduceStrategy.MNNVL: + if MNNVLAllReduce.is_mnnvl(self.mapping, dtype): + try: + self.mnnvl_allreduce = MNNVLAllReduce( + self.mapping, dtype) if dtype else None + if self.mnnvl_allreduce: + logger.debug(f"MNNVLAllReduce is enabled") + else: + logger.debug(f"MNNVLAllReduce is disabled") + except Exception as e: + logger.debug( + f"MNNVL AllReduce can't be enabled due to {e}.") + self.mnnvl_allreduce = None + else: + logger.debug( + f"MNNVLAllReduce can't be enabled due to failing the is_mnnvl check." + ) + self.mnnvl_allreduce = None def forward( self, @@ -458,6 +508,8 @@ class AllReduce(nn.Module): == False): return input + input = input.contiguous() # Underlying op requires contiguous input + allreduce_strategy = self.strategy if all_reduce_params is None: all_reduce_params = AllReduceParams() diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 232d2ccecd..7e310f934a 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -1,4 +1,5 @@ import json +import os from dataclasses import dataclass, field from pathlib import Path from typing import Dict, Generic, List, Optional, TypeVar @@ -8,7 +9,7 @@ import transformers from tensorrt_llm import logger from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid -from tensorrt_llm._utils import torch_dtype_to_binding +from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding from tensorrt_llm.bindings import LayerType as LayerTypeCpp from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.logger import logger @@ -82,6 +83,9 @@ class ModelConfig(Generic[TConfig]): attn_backend: str = 'TRTLLM' moe_backend: str = 'CUTLASS' # options can be CUTLASS, TRTLLM + # IF true, disables FC2+finalize fusion in CUTLASS MoE backend + moe_disable_finalize_fusion: bool = False + allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO # If true, enable min-latency mode. Currently only used for Llama4. @@ -92,6 +96,9 @@ class ModelConfig(Generic[TConfig]): force_dynamic_quantization: bool = False + # If true, use torch.compile for embedding layers. + enable_torch_compile_for_embedding = False + extra_attrs: Dict = field(default_factory=dict, repr=False, init=False) _frozen: bool = field(default=False, init=False, repr=False) @@ -126,7 +133,8 @@ class ModelConfig(Generic[TConfig]): "ONESHOT": AllReduceStrategy.ONESHOT, "TWOSHOT": AllReduceStrategy.TWOSHOT, "LOWPRECISION": AllReduceStrategy.LOWPRECISION, - "MNNVL": AllReduceStrategy.MNNVL + "MNNVL": AllReduceStrategy.MNNVL, + "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC } key = strategy.upper() return maps[key] if key in maps else AllReduceStrategy.AUTO @@ -135,6 +143,14 @@ class ModelConfig(Generic[TConfig]): self.allreduce_strategy = get_all_reduce_strategy( self.allreduce_strategy) + @property + def torch_dtype(self) -> torch.dtype: + """Get the torch dtype of the model.""" + # TODO: this is an assumption that a HF model is always in bfloat16 + # We should figure out a better way to handle this if other models + # start to not report dtype. + return self.pretrained_config.torch_dtype or torch.bfloat16 + @property def fuse_pos_embd(self): if self.attn_backend == 'TRTLLM': @@ -146,8 +162,9 @@ class ModelConfig(Generic[TConfig]): @property def enable_flash_mla(self): if self.attn_backend == 'TRTLLM': - if hasattr(self.pretrained_config, "kv_lora_rank") and hasattr( - self.pretrained_config, "qk_rope_head_dim"): + if getattr(self.pretrained_config, + "kv_lora_rank", None) and getattr( + self.pretrained_config, "qk_rope_head_dim", None): head_dim = self.pretrained_config.kv_lora_rank + self.pretrained_config.qk_rope_head_dim if head_dim == 576 and torch.cuda.get_device_capability() == ( 9, 0): @@ -177,6 +194,158 @@ class ModelConfig(Generic[TConfig]): # TODO: should be 'not model_type == ModelType.ENCODER_ONLY' # once ModelType is used in pytorch flow. + @staticmethod + def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend): + quant_config = QuantConfig() + layer_quant_config = None + + with open(quant_config_file) as f: + quant_config_dict = json.load(f) + + json_quant_configs = quant_config_dict['quantization'] + + quant_config.quant_algo = json_quant_configs.get('quant_algo', None) + # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES + if quant_config.quant_algo == "fp8_pb_wo": + quant_config.quant_algo = 'FP8_BLOCK_SCALES' + quant_config.kv_cache_quant_algo = json_quant_configs.get( + 'kv_cache_quant_algo', None) + quant_config.group_size = json_quant_configs.get('group_size', None) + quant_config.exclude_modules = json_quant_configs.get( + 'exclude_modules', None) + + if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION: + mixed_quant_config_file = model_dir / 'quant_cfg.json' + with open(mixed_quant_config_file) as fm: + mixed_quant_configs = json.load(fm) + # kv_cache_quant_algo is global regardless of MIXED_PRECISION + kv_cache_quant_algo = mixed_quant_configs['kv_cache_quant_algo'] + mixed_quant_configs = mixed_quant_configs['quantized_layers'] + if kv_cache_quant_algo is not None and quant_config.kv_cache_quant_algo is not None: + if kv_cache_quant_algo != quant_config.kv_cache_quant_algo: + raise RuntimeError( + f"The kvcache config in 'quant_cfg.json', {kv_cache_quant_algo}," + f"is different from 'hf_quant_config.json', {quant_config.kv_cache_quant_algo}!" + ) + kv_cache_quant_algo = kv_cache_quant_algo or quant_config.kv_cache_quant_algo + + for layer in mixed_quant_configs: + config = QuantConfig() + config.kv_cache_quant_algo = kv_cache_quant_algo + config.quant_algo = mixed_quant_configs[layer]['quant_algo'] + config.group_size = mixed_quant_configs[layer].get( + 'group_size', None) + mixed_quant_configs[layer] = config + layer_quant_config = mixed_quant_configs + elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES: + if quant_config.group_size is None: + quant_config.group_size = 128 + + if moe_backend == 'TRTLLM' and quant_config.quant_algo == "FP8_BLOCK_SCALES" and quant_config.exclude_modules is None: + quant_config.exclude_modules = [ + "*kv_b_proj*", "*k_b_proj*", "*eh_proj" + ] + return quant_config, layer_quant_config + + @staticmethod + def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False): + quant_algo = ModelConfig.override_quant_algo() + if quant_algo is None and not is_dynamic_quant: + if get_sm_version() >= 100: + if moe_backend == 'TRITON': + return QuantAlgo.W4A8_MXFP4_FP8 + else: + return QuantAlgo.W4A8_MXFP4_MXFP8 + else: + return QuantAlgo.W4A16_MXFP4 + else: + return quant_algo + + @staticmethod + def load_hf_quant_config(hf_quant_config, moe_backend): + quant_config = QuantConfig() + layer_quant_config = None + + # DeepSeek V3 FP8 ckpt + if hf_quant_config.get("quant_method") == "fp8" and hf_quant_config.get( + "weight_block_size", []): + quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES + if moe_backend == 'TRTLLM': + # TODO: This is a hack. Remove after fp8 bmm is integrated. + quant_config.exclude_modules = [ + "*kv_b_proj*", "*k_b_proj*", "*eh_proj" + ] + else: + quant_config.exclude_modules = ["*eh_proj"] + + block_size = hf_quant_config.get("weight_block_size", []) + assert tuple(block_size) == ( + 128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" + quant_config.group_size = block_size[0] + # MXFP4 checkpoints. + elif hf_quant_config.get("quant_method") == "mxfp4": + quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( + moe_backend) + quant_config.group_size = 32 + quant_config.exclude_modules = [ + 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv', + 'embedding', 'unembedding' + ] + + return quant_config, layer_quant_config + + @staticmethod + def load_quant_config_from_dtypes_json(dtypes_json_file, moe_backend: str): + quant_config = QuantConfig() + layer_quant_config = None + + exclude_modules = set() + has_mxfp4 = False + is_dynamic_quant = False + with open(dtypes_json_file) as f: + dtypes_json = json.load(f) + for layer, dtype in dtypes_json.items(): + if layer.endswith("weight"): + if dtype == "BF16" or dtype == "FP16": + names = layer.split(".") + exclude_modules.add('.'.join(names[:-1])) + elif dtype == "MXFP4": + # This is the path for the fp8 checkpoint which requires dynamic quantization. + is_dynamic_quant = True + has_mxfp4 = True + elif layer.endswith("weight.blocks"): + scale_name = layer.replace("weight.blocks", "weight.scales") + scale_dtype = dtypes_json.get(scale_name, None) + assert scale_dtype == "UE8" + is_dynamic_quant = False + has_mxfp4 = True + + if has_mxfp4: + quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( + moe_backend, is_dynamic_quant) + quant_config.group_size = 32 + quant_config.exclude_modules = list(exclude_modules) + logger.info(f"Setting quant_config: {quant_config}") + + return quant_config, layer_quant_config + + @staticmethod + def override_quant_algo(): + new_algo = os.environ.get("OVERRIDE_QUANT_ALGO", None) + supported_algos = { + "W4A16_MXFP4": QuantAlgo.W4A16_MXFP4, + "W4A8_MXFP4_MXFP8": QuantAlgo.W4A8_MXFP4_MXFP8, + "W4A8_MXFP4_FP8": QuantAlgo.W4A8_MXFP4_FP8, + } + if new_algo is not None: + if new_algo.upper() in supported_algos: + return supported_algos[new_algo.upper()] + else: + logger.warning( + f"Unsupported quant algo: {new_algo}, supported algos: {supported_algos.keys()}" + ) + return None + @classmethod def from_pretrained(cls, checkpoint_dir: str, @@ -194,82 +363,20 @@ class ModelConfig(Generic[TConfig]): 'config.json')).parent quant_config = QuantConfig() layer_quant_config = None + moe_backend = kwargs.get('moe_backend', 'CUTLASS') + # quantized ckpt in modelopt format - quant_config_file = model_dir / 'hf_quant_config.json' - if quant_config_file.exists(): - with open(quant_config_file) as f: - quant_config_dict = json.load(f) - - json_quant_configs = quant_config_dict['quantization'] - - quant_config.quant_algo = json_quant_configs.get('quant_algo', None) - # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES - if quant_config.quant_algo == "fp8_pb_wo": - quant_config.quant_algo = 'FP8_BLOCK_SCALES' - quant_config.kv_cache_quant_algo = json_quant_configs.get( - 'kv_cache_quant_algo', None) - quant_config.group_size = json_quant_configs.get('group_size', None) - quant_config.exclude_modules = json_quant_configs.get( - 'exclude_modules', None) - - if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION: - mixed_quant_config_file = model_dir / 'quant_cfg.json' - with open(mixed_quant_config_file) as fm: - mixed_quant_configs = json.load(fm) - # kv_cache_quant_algo is global regardless of MIXED_PRECISION - kv_cache_quant_algo = mixed_quant_configs[ - 'kv_cache_quant_algo'] - mixed_quant_configs = mixed_quant_configs[ - 'quantized_layers'] - if kv_cache_quant_algo is not None and quant_config.kv_cache_quant_algo is not None: - if kv_cache_quant_algo != quant_config.kv_cache_quant_algo: - raise RuntimeError( - f"The kvcache config in 'quant_cfg.json', {kv_cache_quant_algo}," - f"is different from 'hf_quant_config.json', {quant_config.kv_cache_quant_algo}!" - ) - kv_cache_quant_algo = kv_cache_quant_algo or quant_config.kv_cache_quant_algo - - for layer in mixed_quant_configs: - config = QuantConfig() - config.kv_cache_quant_algo = kv_cache_quant_algo - config.quant_algo = mixed_quant_configs[layer][ - 'quant_algo'] - config.group_size = mixed_quant_configs[layer].get( - 'group_size', None) - mixed_quant_configs[layer] = config - layer_quant_config = mixed_quant_configs - elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES: - if quant_config.group_size is None: - quant_config.group_size = 128 - - if kwargs.get( - 'moe_backend' - ) == 'TRTLLM' and quant_config.quant_algo == "FP8_BLOCK_SCALES" and quant_config.exclude_modules is None: - quant_config.exclude_modules = [ - "*kv_b_proj*", "*k_b_proj*", "*eh_proj" - ] - + if (quant_config_file := model_dir / 'hf_quant_config.json').exists(): + quant_config, layer_quant_config = cls.load_modelopt_quant_config( + quant_config_file, model_dir, moe_backend) # quantized ckpt in other formats elif hasattr(pretrained_config, "quantization_config"): hf_quant_config = pretrained_config.quantization_config - # DeepSeek V3 FP8 ckpt - if hf_quant_config.get( - "quant_method") == "fp8" and hf_quant_config.get( - "weight_block_size", []): - quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES - if kwargs.get('moe_backend') == 'TRTLLM': - # TODO: This is a hack. Remove after fp8 bmm is integrated. - quant_config.exclude_modules = [ - "*kv_b_proj*", "*k_b_proj*", "*eh_proj" - ] - else: - quant_config.exclude_modules = ["*eh_proj"] - - block_size = hf_quant_config.get("weight_block_size", []) - assert tuple(block_size) == ( - 128, - 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" - quant_config.group_size = block_size[0] + quant_config, layer_quant_config = cls.load_hf_quant_config( + hf_quant_config, moe_backend) + elif (quant_config_file := model_dir / 'dtypes.json').exists(): + quant_config, layer_quant_config = cls.load_quant_config_from_dtypes_json( + quant_config_file, moe_backend) model_config = cls(pretrained_config=pretrained_config, quant_config=quant_config, diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py index e4da7aff5a..4f7aa39330 100644 --- a/tensorrt_llm/_torch/models/__init__.py +++ b/tensorrt_llm/_torch/models/__init__.py @@ -7,6 +7,8 @@ from .modeling_deepseekv3 import DeepseekV3ForCausalLM from .modeling_exaone4 import Exaone4ForCausalLM from .modeling_gemma3 import Gemma3ForCausalLM from .modeling_gemma3vl import Gemma3VLM +from .modeling_gpt_oss import GptOssForCausalLM +from .modeling_hunyuan_moe import HunYuanMoEV1ForCausalLM from .modeling_hyperclovax import HCXVisionForCausalLM from .modeling_llama import LlamaForCausalLM from .modeling_llava_next import LlavaNextModel @@ -37,6 +39,7 @@ __all__ = [ "Gemma3ForCausalLM", "Gemma3VLM", "HCXVisionForCausalLM", + "HunYuanMoEV1ForCausalLM", "LlamaForCausalLM", "LlavaNextModel", "Mistral3VLM", @@ -58,6 +61,7 @@ __all__ = [ "Qwen2_5_VLModel", "Qwen3ForCausalLM", "Qwen3MoeForCausalLM", + "GptOssForCausalLM", ] if transformers.__version__ >= "4.45.1": diff --git a/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py b/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py index c1bfec0144..10130b087e 100644 --- a/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py @@ -67,8 +67,8 @@ class BaseCheckpointLoader(ABC): f"available formats are: {CHECKPOINT_LOADER_FORMAT_DEFAULT_MAPPING.keys()}" ) - def get_initilized_weight_mapper(self, model: nn.Module, - config: ModelConfig) -> BaseWeightMapper: + def get_initialized_weight_mapper(self, model: nn.Module, + config: ModelConfig) -> BaseWeightMapper: weight_mapper = None if self.weight_mapper is not None: self.weight_mapper.init_model_and_config(model, config) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 66ea5e3a0e..5210d341d4 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -39,7 +39,6 @@ from torch import nn from tqdm import tqdm from transformers import PretrainedConfig -from tensorrt_llm import logger from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType @@ -54,11 +53,11 @@ from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, MoEAllReduce, MoEAllReduceParams, allgather) from ..model_config import ModelConfig -from ..models.modeling_utils import ModelConfig, QuantConfig from ..modules.attention import MLA from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding -from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, TRTLLMGenFusedMoE, +from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, + MoEWeightLoadingMode, TRTLLMGenFusedMoE, create_moe, moe_load_balancer_set_repeated_for_next_layer) from ..modules.gated_mlp import GatedMLP @@ -66,10 +65,10 @@ from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm from ..peft.lora.layer import LoraLayer -from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker +from ..speculative import MTPSpecMetadata, SpecMetadata from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor -from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, - EagerFusionConfig, filter_weights, +from .modeling_speculative import SpecDecOneEngineForCausalLM +from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights, register_auto_model) @@ -454,8 +453,14 @@ class Deepseekv3MoE(nn.Module): False, # In both low‑latency and attention‑DP modes, FusedMoE skips the in‑op all‑reduce. model_config=model_config, override_quant_config=override_quant_config, - aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap], - layer_idx=layer_idx) + aux_stream_dict=aux_stream_dict, + layer_idx=layer_idx, + # DS-R1 W4A8 is only supported through custom quantization script from + # examples/quantization/quantize_mixed_precision_moe.py + weight_loading_mode=(MoEWeightLoadingMode.W4A8_CUSTOM + if model_config.quant_config.quant_mode. + is_int4_weight_only_per_group() else + MoEWeightLoadingMode.VANILLA)) self.mapping = model_config.mapping @@ -535,7 +540,8 @@ class Deepseekv3MoE(nn.Module): router_logits = self.gate(hidden_states) routed_output = self.experts( - hidden_states_fp4 or hidden_states, + hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states, router_logits, do_finalize=do_finalize, output_dtype=hidden_states.dtype, @@ -559,8 +565,9 @@ class Deepseekv3MoE(nn.Module): assert not self.use_dp def _compute_shared_output(): - shared_output = self.shared_experts(hidden_states_fp4 - or hidden_states) + shared_output = self.shared_experts( + hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states) if self.shared_output_scale is not None: shared_output *= self.shared_output_scale return shared_output @@ -744,7 +751,7 @@ class DeepseekV3DecoderLayer(DecoderLayer): attn_metadata: AttentionMetadata, residual: torch.Tensor, **kwargs, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -776,7 +783,7 @@ class DeepseekV3DecoderLayer(DecoderLayer): hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): return self.mlp( @@ -860,7 +867,7 @@ class DeepseekV3DecoderLayer(DecoderLayer): self, hidden_states: torch.Tensor, residual: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: if self.fusion_config.PRE_MLP_FUSION: act_fp4, act_sf, residual = self.allreduce( @@ -964,7 +971,7 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer): all_rank_num_tokens: Optional[List[int]] = None, all_rank_max_num_tokens: Optional[int] = None, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: def norm_embeds(): return self.enorm(embed_tokens(input_ids)) #emdedding @@ -1050,11 +1057,12 @@ class DeepseekV3Model(DecoderModel): config = model_config.pretrained_config self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers - aux_stream_list = [torch.cuda.Stream() for _ in range(2)] + aux_stream_list = [torch.cuda.Stream() for _ in range(3)] self.aux_stream_dict = { AuxStreamType.Attention: aux_stream_list[0], AuxStreamType.MoeShared: aux_stream_list[0], AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], + AuxStreamType.MoeBalancer: aux_stream_list[2], } self.embed_tokens = Embedding( @@ -1078,6 +1086,8 @@ class DeepseekV3Model(DecoderModel): input_ids: Optional[torch.IntTensor] = None, position_ids: Optional[torch.IntTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, ) -> torch.Tensor: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( @@ -1102,8 +1112,8 @@ class DeepseekV3Model(DecoderModel): @register_auto_model("DeepseekV3ForCausalLM") -class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model, - PretrainedConfig]): +class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, + PretrainedConfig]): def __init__(self, model_config: ModelConfig[PretrainedConfig]): # Rename some keys of quant_config_dict to support legacy checkpoints @@ -1118,10 +1128,9 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model, model_config._frozen = False model_config.quant_config_dict = quant_config_dict model_config._frozen = True - super().__init__(DeepseekV3Model(model_config), - config=model_config, - hidden_size=model_config.pretrained_config.hidden_size, - vocab_size=model_config.pretrained_config.vocab_size) + + super().__init__(model=DeepseekV3Model(model_config), + model_config=model_config) self.model_nextn = 0 if model_config.spec_config is not None: @@ -1131,23 +1140,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model, assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint." if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla: moe_load_balancer_set_repeated_for_next_layer(model_nextn) - mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers, - self.model.aux_stream_dict) - self.model.layers.append(mtp_layer) - self.epilogue.append(mtp_layer) - self.mtp_worker = MTPEagleWorker(model_config.spec_config, - model_config) else: - mtp_layers = nn.ModuleList([ - DeepseekV3MTP(model_config, - layer_idx + self.num_hidden_layers, - self.model.aux_stream_dict) - for layer_idx in range(model_nextn) - ]) - self.model.layers.extend(mtp_layers) - self.epilogue.extend(mtp_layers) - self.mtp_worker = MTPWorker(model_config.spec_config, - model_config) # modify the QuantConfig to support duplicated mtp layers if model_config.quant_config.exclude_modules is not None: extend_exclude_modules = [] @@ -1165,7 +1158,9 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model, ckpt_prefix, model_prefix)) self.model_config.quant_config.exclude_modules.extend( extend_exclude_modules) - self.epilogue.append(self.mtp_worker) + self.model.layers.extend(self.draft_model.mtp_layers) + self.epilogue.extend(self.draft_model.mtp_layers) + self.epilogue.append(self.spec_worker) def forward( self, @@ -1178,40 +1173,13 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model, **kwargs, ) -> torch.Tensor: attn_metadata.num_generations_per_batch = self.model_nextn + 1 - hidden_states = self.model( - input_ids=input_ids, - attn_metadata=attn_metadata, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - ) - - if spec_metadata and spec_metadata.spec_dec_mode.is_mtp(): - # get logits - logits = self.logits_processor.forward( - hidden_states[spec_metadata.gather_ids], - self.lm_head, - attn_metadata, - True, - ) - # get accepted tokens and next draft tokens - return self.mtp_worker( - input_ids=input_ids, - position_ids=position_ids, - hidden_states=hidden_states, - logits=logits, - lm_head=self.lm_head, - embed_tokens=self.model.embed_tokens, - attn_metadata=attn_metadata, - spec_metadata=spec_metadata, - mtp_layers=self.model.layers[self.num_hidden_layers:]) - else: - logits = self.logits_processor.forward( - hidden_states, - self.lm_head, - attn_metadata, - return_context_logits, - ) - return logits + return super().forward(attn_metadata=attn_metadata, + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + spec_metadata=spec_metadata, + return_context_logits=return_context_logits, + **kwargs) def load_weights(self, weights: Dict): @@ -1344,21 +1312,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model, params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} all_named_modules = dict(self.named_modules()) - if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( - ) and get_sm_version() == 100: - for name in list(weights.keys()): - # Use ".experts." to exclude shared_experts. - if name.endswith( - "weight_scale_inv") and ".experts." not in name: - weight_name = name.replace("weight_scale_inv", "weight") - logger.debug(f"Resmoothing {weight_name}") - weight = weights[weight_name][:] - scale = weights[name][:] - weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( - weight, scale) - weights[weight_name] = weights[weight_name].cpu() - weights[name] = weights[name].cpu() - for name, module in tqdm(all_named_modules.items(), desc="Loading weights"): if len(module._parameters) > 0: @@ -1480,12 +1433,15 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model, if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( ) and get_sm_version() == 100 and hasattr( module, "weight_scale"): + weight, weight_scale = resmooth_to_fp8_e8m0( + module.weight, module.weight_scale) transfromed_scale = transform_sf_into_required_layout( - module.weight_scale, - mn=module.weight.shape[0], - k=module.weight.shape[1], + weight_scale, + mn=weight.shape[0], + k=weight.shape[1], recipe=(1, 128, 128), is_sfa=False) + module.weight = nn.Parameter(weight, requires_grad=False) module.weight_scale = nn.Parameter(transfromed_scale, requires_grad=False) diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py index 699e464151..e305b82dba 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py @@ -10,7 +10,9 @@ from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper from ..._utils import nvtx_range -from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, +from ...inputs import (ExtraProcessedInputs, InputProcessor, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...logger import logger from ...sampling_params import SamplingParams @@ -137,7 +139,13 @@ class Gemma3MultiModalProjector(torch.nn.Module): @register_auto_model("Gemma3ForConditionalGeneration") -@register_input_processor(Gemma3InputProcessor, model_type="gemma3") +@register_input_processor( + Gemma3InputProcessor, + model_type="gemma3", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={"image": "<start_of_image>"}, + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + )) class Gemma3VLM(PreTrainedModel): def __init__(self, model_config: ModelConfig[Gemma3Config]): diff --git a/tensorrt_llm/_torch/models/modeling_gpt_oss.py b/tensorrt_llm/_torch/models/modeling_gpt_oss.py new file mode 100644 index 0000000000..5ea69fefb6 --- /dev/null +++ b/tensorrt_llm/_torch/models/modeling_gpt_oss.py @@ -0,0 +1,912 @@ +import os +from typing import Dict, Optional, Tuple + +import torch +from torch import nn +from torch.nn.parameter import Parameter +from tqdm import tqdm +from transformers import GptOssConfig + +from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType + +from ..attention_backend import AttentionMetadata +from ..attention_backend.interface import (PositionalEmbeddingParams, + PredefinedAttentionMask, RopeParams) +from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, + allgather) +from ..model_config import ModelConfig +from ..modules.attention import Attention +from ..modules.decoder_layer import DecoderLayer +from ..modules.embedding import Embedding + +# isort and yapf will fight against each other here, so we disable isort +# isort: off +from ..modules.fused_moe import (MoE, MoEWeightLoadingMode, + RenormalizeMoeRoutingMethod, TritonFusedMoE, + TRTLLMGenFusedMoE, create_moe) +from ..modules.fused_moe.routing import (get_cached_perfect_router_logits, + precompute_common_perfect_router_logits + ) +# isort: on +from ..modules.linear import Linear, TensorParallelMode +from ..modules.rms_norm import RMSNorm +from ..speculative import SpecMetadata +from ..utils import Fp4QuantizedTensor +from .modeling_speculative import SpecDecOneEngineForCausalLM +from .modeling_utils import (DecoderModel, duplicate_kv_weight, filter_weights, + register_auto_model) + + +class AttentionBlock(Attention): + + def __init__( + self, + config: ModelConfig[GptOssConfig], + layer_idx: int = 0, + ): + pretrained_config = config.pretrained_config + + pos_embd_params = PositionalEmbeddingParams( + type=PositionEmbeddingType.yarn, + rope=RopeParams( + dim=pretrained_config.head_dim, + theta=pretrained_config.rope_theta, + scale_type=RotaryScalingType.yarn, + scale=pretrained_config.rope_scaling['factor'], + max_positions=pretrained_config.max_position_embeddings, + original_max_positions=pretrained_config. + rope_scaling['original_max_position_embeddings'], + beta_fast=pretrained_config.rope_scaling['beta_fast'], + beta_slow=pretrained_config.rope_scaling['beta_slow'], + duplicate_data=False), + is_neox=False, + ) + + super().__init__( + hidden_size=pretrained_config.hidden_size, + num_attention_heads=pretrained_config.num_attention_heads, + num_key_value_heads=pretrained_config.num_key_value_heads, + max_position_embeddings=pretrained_config.max_position_embeddings, + bias=True, + pos_embd_params=pos_embd_params, + layer_idx=layer_idx, + dtype=pretrained_config.torch_dtype, + dense_bias=True, + config=config, + q_scaling=1.0, + attention_chunk_size=None, + ) + + # Only apply sliding window to every other layer + self.sliding_window = pretrained_config.sliding_window if layer_idx % 2 == 0 else None + + self.sinks = Parameter( + torch.empty(pretrained_config.num_attention_heads // self.tp_size, + dtype=torch.float32)) + self.norm = RMSNorm(hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype) + + def forward( + self, + position_ids: Optional[torch.LongTensor], + hidden_states: torch.Tensor | Fp4QuantizedTensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor] = None, + attention_mask: PredefinedAttentionMask = PredefinedAttentionMask. + CAUSAL, + all_reduce_params: Optional[AllReduceParams] = None, + lora_params: Optional[dict] = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + + attention_window_size = self.sliding_window + attn_output = super().forward( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + attention_mask=attention_mask, + all_reduce_params=all_reduce_params, + lora_params=lora_params, + attention_window_size=attention_window_size, + attention_sinks=self.sinks.data, + **kwargs) + return attn_output, residual + + def load_weights(self, weights: Dict): + sinks = weights[0]['sinks'][self.num_heads * + self.tp_rank:self.num_heads * + (self.tp_rank + 1)] + self.sinks.data = sinks.to(torch.float32).to("cuda") + + +class MLPBlock(torch.nn.Module): + + def __init__( + self, + config: ModelConfig[GptOssConfig], + layer_idx: int, + reduce_results: bool = True, + ): + super().__init__() + + self.config = config # Store config as instance variable + pretrained_config = config.pretrained_config + self.num_experts = pretrained_config.num_local_experts + self.layer_idx = layer_idx + self.enable_attention_dp = config.mapping.enable_attention_dp + self.mapping = config.mapping + + self.norm = RMSNorm(hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype) + + self.gate = Linear( + in_features=pretrained_config.hidden_size, + out_features=pretrained_config.num_local_experts, + bias=True, + dtype=pretrained_config.torch_dtype, + use_custom_cublas_mm= + False, # TODO: check perf & cublass mm can not support bias. + ) + + self.routing_method = RenormalizeMoeRoutingMethod( + top_k=pretrained_config.num_experts_per_tok) + self.swiglu_alpha = torch.tensor( + [1.702] * (self.num_experts // config.mapping.moe_ep_size), + dtype=torch.float32).cuda() + self.swiglu_beta = torch.tensor( + [1.0] * (self.num_experts // config.mapping.moe_ep_size), + dtype=torch.float32).cuda() + self.swiglu_limit = torch.tensor( + [7.0] * (self.num_experts // config.mapping.moe_ep_size), + dtype=torch.float32).cuda() + # Prepare MoE creation parameters + moe_params = { + 'routing_method': self.routing_method, + 'num_experts': pretrained_config.num_local_experts, + 'hidden_size': pretrained_config.hidden_size, + 'intermediate_size': pretrained_config.intermediate_size, + 'dtype': pretrained_config.torch_dtype, + 'reduce_results': not self.enable_attention_dp and reduce_results, + 'model_config': config, + 'weight_loading_mode': MoEWeightLoadingMode.FUSED_GATE_UP_PROJ, + 'bias': True, + 'swiglu_alpha': self.swiglu_alpha, + 'swiglu_beta': self.swiglu_beta, + 'swiglu_limit': self.swiglu_limit + } + + self.experts = create_moe(**moe_params) + + # Perfect router caching - precompute common logits if enabled + if os.environ.get('ENABLE_PERFECT_ROUTER', '0') == '1': + precompute_common_perfect_router_logits( + num_experts=pretrained_config.num_local_experts, + experts_per_token=pretrained_config.num_experts_per_tok, + moe_ep_size=config.mapping.moe_ep_size, + dtype=pretrained_config.torch_dtype) + + @staticmethod + def swiglu(x, alpha: float = 1.702): + """ + This function is not really used in self.forward(), it's kept here for reference. + It's implemented as part of the MoE kernels. + """ + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu * (x_linear + 1) + + def _create_ideal_expert_load_balanced_logits( + self, num_tokens: int, num_experts: int, + device: torch.device) -> torch.Tensor: + """ + Create ideal logits that produce GPU-aware load balanced expert assignment. + This method now uses the global cache to access precomputed logits to optimize performance. + """ + pretrained_config = self.config.pretrained_config + + # Use global cached logits + return get_cached_perfect_router_logits( + num_tokens=num_tokens, + num_experts=num_experts, + experts_per_token=pretrained_config.experts_per_token, + moe_ep_size=self.config.mapping.moe_ep_size, + device=device, + dtype=pretrained_config.torch_dtype) + + def forward_normal( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + orig_shape = x.shape + hidden_dim = orig_shape[-1] + x = x.view(-1, hidden_dim) + + # t = self.norm(x) was done in the parent block + t = x + + g = self.gate(t) + # Use ideal load balanced logits if enabled, otherwise use gate output + if os.environ.get('ENABLE_PERFECT_ROUTER', '0') == '1': + # WARNING: This discards the learned gate output and uses ideal logits for perfect load balancing + # Only use this for testing load balancing strategies, not for actual inference + num_tokens, num_experts = g.shape + g = self._create_ideal_expert_load_balanced_logits( + num_tokens=num_tokens, num_experts=num_experts, device=x.device) + + # When attention_dp is not enabled, don't pass those parameters + expert_output = self.experts(x=t, router_logits=g) + + expert_output = expert_output.view(orig_shape) + return expert_output, residual + + def forward_attn_dp( + self, + x: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + orig_shape = x.shape + hidden_dim = orig_shape[-1] + x = x.view(-1, hidden_dim) + + # t = self.norm(x) was done in the parent block + t = x + + # Get attention_dp parameters + all_rank_num_tokens = attn_metadata.all_rank_num_tokens + all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens + + if self.mapping.tp_size > 1 and all_rank_num_tokens is not None: + if (isinstance(self.experts, (TRTLLMGenFusedMoE, TritonFusedMoE))): + t = allgather(t, self.mapping, dim=0, sizes=all_rank_num_tokens) + + g = self.gate(t) + # Use ideal load balanced logits if enabled, otherwise use gate output + if os.environ.get('ENABLE_PERFECT_ROUTER', '0') == '1': + # WARNING: This discards the learned gate output and uses ideal logits for perfect load balancing + # Only use this for testing load balancing strategies, not for actual inference + # The gate is still computed to maintain realistic performance measurement + num_tokens, num_experts = g.shape + g = self._create_ideal_expert_load_balanced_logits( + num_tokens=num_tokens, num_experts=num_experts, device=x.device) + + # Let CutlassFusedMoE handle allgather internally + # Pass the normalized tensor (t) as input to experts, not x + expert_output = self.experts( + x=t, + router_logits=g, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=all_rank_max_num_tokens, + use_dp_padding=False) + + expert_output = expert_output.view(orig_shape) + return expert_output, residual + + def forward( + self, + x: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.enable_attention_dp: + return self.forward_attn_dp(x, attn_metadata, residual) + else: + return self.forward_normal(x, residual) + + +class TransformerBlock(DecoderLayer): + + def __init__( + self, + config: ModelConfig[GptOssConfig], + layer_idx: int, + ): + super().__init__() + self.layer_idx = layer_idx + + mapping = config.mapping + self.enable_attn_dp = mapping.enable_attention_dp + self.is_tp = mapping.has_tp() and not self.enable_attn_dp + + pretrained_config = config.pretrained_config + self.input_layernorm = RMSNorm( + hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype) + + self.attn = AttentionBlock(config, layer_idx) + + self.post_attention_layernorm = RMSNorm( + hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype) + + self.mlp = MLPBlock(config, layer_idx, reduce_results=not self.is_tp) + + self.mapping = config.mapping + + self.next_layer_layernorm = RMSNorm( + hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype) + + # setup for tp + self.allreduce = AllReduce(mapping=config.mapping, + strategy=config.allreduce_strategy, + dtype=config.pretrained_config.torch_dtype) + + def forward_normal( + self, + position_ids: torch.LongTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor] = ..., + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + x, residual = self.attn(position_ids, + hidden_states, + attn_metadata, + residual=residual, + **kwargs) + x, residual = self.post_attention_layernorm(x, residual) + + x, residual = self.mlp(x, attn_metadata, residual) + + if spec_metadata is not None: + spec_metadata.maybe_capture_hidden_states(self.layer_idx, x, + residual) + + x, residual = self.next_layer_layernorm(x, residual) + return x, residual + + def forward_tp( + self, + position_ids: torch.LongTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor] = ..., + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + x, residual = self.attn( + position_ids, + hidden_states, + attn_metadata, + residual=residual, + all_reduce_params=AllReduceParams(enable_allreduce=False), + **kwargs) + + x, residual = self.allreduce( + x, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + eps=self.post_attention_layernorm.variance_epsilon, + trigger_completion_at_end=False, + )) + + x, residual = self.mlp(x, attn_metadata, residual) + + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + # In eagle3 mode, we capture the value in the boundary of decoder layer. + # If fusing rms in the next layer, the value is not correct. Thus, if + # this layer will be captured, we should not fuse the rms in the next + # layer. + x = self.allreduce(x, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.NONE, + trigger_completion_at_end=False, + )) + spec_metadata.maybe_capture_hidden_states(self.layer_idx, x, + residual) + x, residual = self.next_layer_layernorm(x, residual) + else: + x, residual = self.allreduce( + x, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + trigger_completion_at_end=False, + )) + + return x, residual + + def forward( + self, + position_ids: torch.LongTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor] = ..., + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if self.is_tp: + _forward = self.forward_tp + else: + _forward = self.forward_normal + return _forward(position_ids, + hidden_states, + attn_metadata, + residual, + spec_metadata=spec_metadata, + **kwargs) + + +class Transformer(DecoderModel): + + def __init__(self, model_config: ModelConfig[GptOssConfig]): + super().__init__(model_config) + config = self.model_config + + # Triton MoE kernels require installing Triton main branch, + # which may be incompatible with torch.compile due to version mismatch. + enable_torch_compile_for_embedding = model_config.moe_backend != "TRITON" + + if model_config.mapping.enable_attention_dp: + # When attention_dp is enabled, we cannot do all_reduce since + # the problem size of different ranks are different. + # So, we don't do parallelism here. + self.embedding = Embedding( + config.pretrained_config.vocab_size, + config.pretrained_config.hidden_size, + dtype=config.pretrained_config.torch_dtype, + enable_torch_compile_for_embedding= + enable_torch_compile_for_embedding) + else: + self.embedding = Embedding( + config.pretrained_config.vocab_size, + config.pretrained_config.hidden_size, + dtype=config.pretrained_config.torch_dtype, + mapping=config.mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + gather_output=True, + enable_torch_compile_for_embedding= + enable_torch_compile_for_embedding, + ) + # For modeling_speculative, different name expected + self.embed_tokens = self.embedding + self.block = nn.ModuleList([ + TransformerBlock( + model_config, + layer_idx, + ) for layer_idx in range(config.pretrained_config.num_hidden_layers) + ]) + self.norm = RMSNorm( + hidden_size=config.pretrained_config.hidden_size, + eps=config.pretrained_config.rms_norm_eps, + dtype=config.pretrained_config.torch_dtype, + ) + + def forward( + self, + attn_metadata: AttentionMetadata, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + spec_metadata: Optional[SpecMetadata] = None, + mrope_config: Optional[Tuple[torch.Tensor, int]] = None, + **kwargs, + ) -> torch.Tensor: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + hidden_states = inputs_embeds or self.embedding(input_ids) + + residual = None + for block in self.block: + hidden_states, residual = block( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + mrope_config=mrope_config, + spec_metadata=spec_metadata, + ) + + return hidden_states + + +@register_auto_model("GptOssForCausalLM") +class GptOssForCausalLM(SpecDecOneEngineForCausalLM[Transformer, GptOssConfig]): + + params_map = { + # TRTLLM module name : GptOss module name + "qkv_proj": "qkv", + "o_proj": "out", + "lm_head": "unembedding", + } + + hf_params_map = { + # TRTLLM module name : HuggingFace module name + "embedding": "embed_tokens", + # Order matters for attn.norm and attn. + 'attn.norm': 'input_layernorm', + 'attn': 'self_attn', + 'mlp.norm': 'post_attention_layernorm', + 'block': 'layers', + 'gate': 'router', + } + + def __init__( + self, + model_config: ModelConfig[GptOssConfig], + ): + # Map config to HF format. + if hasattr(model_config.pretrained_config, 'num_experts'): + model_config.pretrained_config.num_local_experts = model_config.pretrained_config.num_experts + model_config.pretrained_config.num_experts_per_tok = model_config.pretrained_config.experts_per_token + model_config.pretrained_config.rope_scaling = { + 'factor': + model_config.pretrained_config.rope_scaling_factor, + 'beta_fast': + model_config.pretrained_config.rope_ntk_beta, + 'beta_slow': + model_config.pretrained_config.rope_ntk_alpha, + 'original_max_position_embeddings': + model_config.pretrained_config.initial_context_length, + } + if model_config.pretrained_config.torch_dtype is None: + model_config.pretrained_config.torch_dtype = torch.bfloat16 + + super().__init__( + Transformer(model_config), + model_config=model_config, + ) + + def __post_init__(self): + # Do not call super().__post_init__() + params_map_reverse = {v: k for k, v in self.params_map.items()} + + quant_config = self.model_config.quant_config + if quant_config.exclude_modules: + for i, module in enumerate(quant_config.exclude_modules): + names = module.split(".") + if names[-1] in params_map_reverse: + names[-1] = params_map_reverse[names[-1]] + prefix = [] if names[0] == "model" else ["model"] + quant_config.exclude_modules[i] = '.'.join(prefix + names) + + super().apply_quant_config_exclude_modules() + + for _, module in self.named_modules(): + if callable(getattr(module, "create_weights", None)): + module.create_weights() + + def load_weights(self, weights: Dict): + is_ori_model = True + for k, v in weights.items(): + if 'q_proj' in k: + is_ori_model = False + + if is_ori_model: + self.load_ori_weights(weights) + else: + self.load_hf_weights(weights) + + for idx, layer in enumerate( + self.model.block[:self.config.num_hidden_layers]): + if idx == 0: + layer.input_layernorm = layer.attn.norm + + layer.post_attention_layernorm = layer.mlp.norm + + if idx == self.config.num_hidden_layers - 1: + layer.next_layer_layernorm = self.model.norm + else: + layer.next_layer_layernorm = self.model.block[idx + 1].attn.norm + + def load_hf_weights(self, weights: Dict): + num_expert = self.config.num_local_experts + + for name, module in tqdm(list(self.named_modules()), + desc="Loading weights"): + if len(module._parameters) <= 0 or name.startswith("draft_model"): + continue + + module_weights = {} + for k, v in self.hf_params_map.items(): + name = name.replace(k, v) + module_weights = filter_weights(name, weights) + + if isinstance(module, MoE): + try: + # For BF16 ckpt. + # Deinterleave for gate and up. + gate_up_weight = module_weights['gate_up_proj'] + gate, up = gate_up_weight[:, :, ::2], gate_up_weight[:, :, + 1::2] + gate_up_weight = torch.cat([gate, up], dim=-1) + gate_up_bias = module_weights['gate_up_proj_bias'] + gate, up = gate_up_bias[:, ::2], gate_up_bias[:, 1::2] + gate_up_bias = torch.cat([gate, up], dim=-1) + moe_weights = { + 'gate_up_proj': [ + gate_up_weight.to(self.model.dtype)[i, :, :] + for i in range(num_expert) + ], + 'down_proj': [ + module_weights['down_proj'][i, :, :].to( + self.model.dtype) for i in range(num_expert) + ], + 'gate_up_proj.bias': + [gate_up_bias[i, :] for i in range(num_expert)], + 'down_proj.bias': [ + module_weights['down_proj_bias'][i, :] + for i in range(num_expert) + ] + } + except: + # For MXFP4 ckpt. + # Deinterleave for gate and up. + gate_up_weight = module_weights[ + 'gate_up_proj_blocks'].flatten(-2, -1) + gate_weight, up_weight = gate_up_weight[:, :: + 2, :], gate_up_weight[:, + 1:: + 2, :] + gate_up_weight = torch.cat([gate_weight, up_weight], dim=-2) + gate_up_bias = module_weights['gate_up_proj_bias'] + gate_bias, up_bias = gate_up_bias[:, :: + 2], gate_up_bias[:, 1::2] + gate_up_bias = torch.cat([gate_bias, up_bias], dim=-1) + gate_up_weight_scale = module_weights['gate_up_proj_scales'] + gate_weight_scale, up_weight_scale = gate_up_weight_scale[:, :: + 2, :], gate_up_weight_scale[:, + 1:: + 2, :] + gate_up_weight_scale = torch.cat( + [gate_weight_scale, up_weight_scale], dim=-2) + moe_weights = { + 'gate_up_proj': [ + gate_up_weight[i, :, :].transpose(0, 1) + for i in range(num_expert) + ], + 'down_proj': [ + module_weights['down_proj_blocks'].flatten( + -2, -1)[i, :, :].transpose(0, 1) + for i in range(num_expert) + ], + 'gate_up_proj.bias': + [gate_up_bias[i, :] for i in range(num_expert)], + 'down_proj.bias': [ + module_weights['down_proj_bias'][i, :] + for i in range(num_expert) + ], + 'gate_up_proj_weight_scale': [ + gate_up_weight_scale[i, :, :].transpose(0, 1) + for i in range(num_expert) + ], + 'down_proj_weight_scale': [ + module_weights['down_proj_scales'] + [i, :, :].transpose(0, 1) for i in range(num_expert) + ] + } + + if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4': + for i in range(num_expert): + moe_weights[ + f"{i}.w1.weight_scale_inv"] = gate_weight_scale[ + i, :, :] + moe_weights[ + f"{i}.w3.weight_scale_inv"] = up_weight_scale[ + i, :, :] + moe_weights[ + f"{i}.w2.weight_scale_inv"] = module_weights[ + 'down_proj_scales'][i, :, :] + + module.load_weights(weights=[moe_weights]) + elif hasattr(module, "load_weights"): + if 'qkv' in name: + # For qkv_proj + q_weight_bias = filter_weights( + name.replace('qkv_proj', 'q_proj'), weights) + k_weight_bias = filter_weights( + name.replace('qkv_proj', 'k_proj'), weights) + v_weight_bias = filter_weights( + name.replace('qkv_proj', 'v_proj'), weights) + module.load_weights( + weights=[q_weight_bias, k_weight_bias, v_weight_bias]) + else: + # For o_proj, sinks. + module.load_weights(weights=[module_weights]) + else: + # Load four LN weights (attn.norm, mlp.norm, input_layernorm, post_attention_layernorm). + if 'next_layer_layernorm' in name: + continue + + for n, p in module._parameters.items(): + if p is not None: + p.data.copy_(module_weights[n][:]) + + def load_ori_weights(self, weights: Dict): + head_dim = self.config.head_dim + num_q_head = self.config.num_attention_heads + num_kv_head = self.config.num_key_value_heads + num_expert = self.config.num_local_experts + enable_attention_dp = self.model_config.mapping.enable_attention_dp + tp_size = self.model_config.mapping.tp_size + + for name, module in tqdm(list(self.named_modules()), + desc="Loading weights"): + if len(module._parameters) <= 0 or name.startswith("draft_model"): + continue + names = name.split(".") + module_weights = {} + if names[-1] in self.params_map: + names[-1] = self.params_map[names[-1]] + + # Drop the first "model" prefix + if names[0] == 'model': + name = '.'.join(names[1:]) + else: + name = '.'.join(names) + module_weights = filter_weights(name, weights) + if isinstance(module, MoE): + # [num_experts, intermediate_size * 2, hidden_size] + gate_up_proj = filter_weights(name.replace("experts", "mlp1"), + weights) + # [num_experts, intermediate_size, hidden_size] + down_proj = filter_weights(name.replace("experts", "mlp2"), + weights) + try: + # Official MXFP4 ckpt. + gate_up_weight = gate_up_proj['weight.blocks'].flatten( + -2, -1) + gate, up = gate_up_weight[:, ::2, :], gate_up_weight[:, 1:: + 2, :] + gate_up_weight = torch.cat([gate, up], dim=-2) + gate_up_bias = gate_up_proj['bias'] + gate, up = gate_up_bias[:, ::2], gate_up_bias[:, 1::2] + gate_up_bias = torch.cat([gate, up], dim=-1) + moe_weights = { + 'gate_up_proj': [ + gate_up_weight[i, :, :].transpose(0, 1) + for i in range(num_expert) + ], + 'down_proj': [ + down_proj['weight.blocks'].flatten( + -2, -1)[i, :, :].transpose(0, 1) + for i in range(num_expert) + ], + 'gate_up_proj.bias': + [gate_up_bias[i, :] for i in range(num_expert)], + 'down_proj.bias': + [down_proj['bias'][i, :] for i in range(num_expert)] + } + except: + # For BF16 ckpt. + moe_weights = { + 'gate_up_proj': [ + gate_up_proj['weight'][i, :, :].transpose(0, 1).to( + self.model.dtype) for i in range(num_expert) + ], + 'down_proj': [ + down_proj['weight'][i, :, :].transpose(0, 1).to( + self.model.dtype) for i in range(num_expert) + ], + 'gate_up_proj.bias': + [gate_up_proj['bias'][i, :] for i in range(num_expert)], + 'down_proj.bias': + [down_proj['bias'][i, :] for i in range(num_expert)] + } + # Only for Official MXFP4 ckpt. + if 'weight.scales' in gate_up_proj: + gate_up_weight_scale = gate_up_proj['weight.scales'] + gate, up = gate_up_weight_scale[:, :: + 2, :], gate_up_weight_scale[:, + 1:: + 2, :] + gate_up_weight_scale = torch.cat([gate, up], dim=-2) + moe_weights['gate_up_proj_weight_scale'] = [ + gate_up_weight_scale[i, :, :].transpose(0, 1) + for i in range(num_expert) + ] + + if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4': + for i in range(num_expert): + moe_weights[f"{i}.w1.weight_scale_inv"] = gate[ + i, :, :] + moe_weights[f"{i}.w3.weight_scale_inv"] = up[ + i, :, :] + + if 'weight.scales' in down_proj: + moe_weights['down_proj_weight_scale'] = [ + down_proj['weight.scales'][i, :, :].transpose(0, 1) + for i in range(num_expert) + ] + + if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4': + for i in range(num_expert): + moe_weights[f"{i}.w2.weight_scale_inv"] = down_proj[ + 'weight.scales'][i, :, :] + + module.load_weights(weights=[moe_weights]) + elif hasattr(module, "load_weights"): + # Load Attention module weights. + if 'qkv' in name: + q_weight = module_weights['weight'][:head_dim * + num_q_head, :] + k_weight = module_weights['weight'][head_dim * + num_q_head:head_dim * + (num_q_head + + num_kv_head), :] + v_weight = module_weights['weight'][-head_dim * + num_kv_head:, :] + q_bias = module_weights['bias'][:head_dim * num_q_head] + k_bias = module_weights['bias'][head_dim * + num_q_head:head_dim * + (num_q_head + num_kv_head)] + v_bias = module_weights['bias'][-head_dim * num_kv_head:] + + # Handle KV weight duplication for GQA + tensors_need_duplication = ['weight', 'bias'] + if module.quant_config.quant_mode.has_mxfp4(): + tensors_need_duplication.append('weight_scale') + + # Duplicate KV weights if needed + tensor_parallel_size = tp_size if not enable_attention_dp else 1 + + k_weight_dict = {'weight': k_weight, 'bias': k_bias} + v_weight_dict = {'weight': v_weight, 'bias': v_bias} + + if 'weight_scale' in module_weights: + k_weight_dict['weight_scale'] = module_weights[ + 'weight_scale'][head_dim * num_q_head:head_dim * + (num_q_head + num_kv_head), :] + v_weight_dict['weight_scale'] = module_weights[ + 'weight_scale'][-head_dim * num_kv_head:, :] + + k_weight_dict = { + k: (duplicate_kv_weight( + weight=v, + num_kv_heads=num_kv_head, + tensor_parallel_size=tensor_parallel_size) + if k in tensors_need_duplication else v) + for k, v in k_weight_dict.items() + } + + v_weight_dict = { + k: (duplicate_kv_weight( + weight=v, + num_kv_heads=num_kv_head, + tensor_parallel_size=tensor_parallel_size) + if k in tensors_need_duplication else v) + for k, v in v_weight_dict.items() + } + + qkv_weights = [{ + 'weight': q_weight, + 'bias': q_bias + }, k_weight_dict, v_weight_dict] + module.load_weights(weights=qkv_weights) + else: + # Dense & gate & sinks + module.load_weights(weights=[module_weights]) + else: + # Load LN weights. + if names[-1].endswith("layernorm") and names[-3] == "block": + # skip loading weights for the fused norms + continue + for n, p in module._parameters.items(): + if p is not None: + p.data.copy_(module_weights[n.replace( + "weight", "scale")][:]) diff --git a/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py b/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py new file mode 100644 index 0000000000..6ebb6f7e53 --- /dev/null +++ b/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py @@ -0,0 +1,433 @@ +from typing import Dict, Optional, Union + +import torch +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig + +from tensorrt_llm._torch.distributed import AllReduceParams +from tensorrt_llm.functional import PositionEmbeddingType + +from ..attention_backend import AttentionMetadata +from ..attention_backend.interface import (PositionalEmbeddingParams, + PredefinedAttentionMask, RopeParams) +from ..model_config import ModelConfig +from ..modules.attention import Attention +from ..modules.decoder_layer import DecoderLayer +from ..modules.embedding import Embedding +from ..modules.fused_moe import (CutlassFusedMoE, RenormalizeMoeRoutingMethod, + VanillaMoE, create_moe) +from ..modules.gated_mlp import GatedMLP +from ..modules.linear import Linear, TensorParallelMode +from ..modules.multi_stream_utils import maybe_execute_in_parallel +from ..modules.rms_norm import RMSNorm +from ..utils import AuxStreamType, Fp4QuantizedTensor +from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, + duplicate_kv_weight, register_auto_model) + + +class HunyuanMoE(nn.Module): + + def __init__( + self, + model_config: ModelConfig[PretrainedConfig], + aux_stream: torch.cuda.Stream, + ): + super().__init__() + config = model_config.pretrained_config + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.moe_intermediate_size = config.moe_intermediate_size[0] \ + if isinstance(config.moe_intermediate_size, list) else config.moe_intermediate_size + self.num_experts = config.num_experts + self.top_k = config.moe_topk[0] \ + if isinstance(config.moe_topk, list) else config.moe_topk + self.enable_attention_dp = model_config.mapping.enable_attention_dp + + # moe gate (linear layer) only runs in half/full precision for now + self.gate = Linear(self.hidden_dim, + self.num_experts, + bias=False, + dtype=config.torch_dtype) + + reduce_results = True + + self.experts = create_moe( + num_experts=self.num_experts, + routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k), + hidden_size=self.hidden_dim, + intermediate_size=self.moe_intermediate_size, + aux_stream=aux_stream, + dtype=config.torch_dtype, + reduce_results=reduce_results, + model_config=model_config) + + self.shared_mlp = GatedMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=config.mlp_bias if hasattr(config, 'mlp_bias') else False, + dtype=config.torch_dtype, + config=model_config, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + assert hidden_states.shape[-1] == self.hidden_dim + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_dim) + + shared_expert_output = self.shared_mlp(hidden_states) + all_rank_num_tokens = attn_metadata.all_rank_num_tokens + router_logits = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=False) + + final_hidden_states = shared_expert_output + final_hidden_states + + return final_hidden_states.view(orig_shape) + + +class HunYuanAttention(Attention): + + def __init__( + self, + model_config: ModelConfig[PretrainedConfig], + layer_idx: Optional[int] = None, + use_qk_norm: bool = True, + nope_layer: bool = False, + aux_stream: Optional[torch.cuda.Stream] = None, + ): + config = model_config.pretrained_config + + self.use_rope = not nope_layer + pos_embd_params = PositionalEmbeddingParams( + type=PositionEmbeddingType.rope_gpt_neox, + rope=RopeParams.from_config(config), + is_neox=True, + ) if self.use_rope else None + self.use_qk_norm = use_qk_norm + + super().__init__( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + bias=config.attention_bias, + pos_embd_params=pos_embd_params, + rope_fusion=not self.use_qk_norm, + layer_idx=layer_idx, + dtype=config.torch_dtype, + config=model_config, + ) + + self.head_dim = config.hidden_size // config.num_attention_heads + self.query_layernorm = RMSNorm(hidden_size=self.head_dim, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + self.key_layernorm = RMSNorm(hidden_size=self.head_dim, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + self.aux_stream = aux_stream + self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] + + def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], + v: Optional[torch.Tensor], position_ids: torch.Tensor): + q, k, v = self.split_qkv(q, k, v) + if position_ids is not None: + q, k, v = super().apply_rope(q, k, v, position_ids) + # Llama4 applies QK norm after RoPE. + if self.use_qk_norm: + q, k = self.apply_qk_norm(q, k) + + return q, k, v + + def apply_qk_norm(self, q, k): + + def q_l2norm(): + return self.query_layernorm(q.reshape(-1, self.head_dim)).reshape( + -1, self.q_size) + + def k_l2norm(): + return self.key_layernorm(k.reshape(-1, self.head_dim)).reshape( + -1, self.kv_size) + + q, k = maybe_execute_in_parallel( + q_l2norm, + k_l2norm, + self.ln_events[0], + self.ln_events[1], + self.aux_stream, + ) + + return q, k + + def forward( + self, + position_ids: Optional[torch.IntTensor], + hidden_states: Union[torch.Tensor, Fp4QuantizedTensor], + attn_metadata: AttentionMetadata, + attention_mask: PredefinedAttentionMask = PredefinedAttentionMask. + CAUSAL, + mrope_config: Optional[dict] = None, + all_reduce_params: Optional[AllReduceParams] = None, + lora_params: Optional[dict] = None, + **kwargs, + ) -> torch.Tensor: + assert lora_params is None, "LORA is not supported for HunYuanAttention" + return super().forward( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + attention_mask=attention_mask, + mrope_config=mrope_config, + all_reduce_params=all_reduce_params, + lora_params=lora_params, + **kwargs, + ) + + +class HunYuanDecoderLayer(DecoderLayer): + + def __init__(self, model_config: ModelConfig[PretrainedConfig], + layer_idx: int, aux_stream_dict: Dict[AuxStreamType, + torch.cuda.Stream]): + super().__init__() + config = model_config.pretrained_config + self.layer_idx = layer_idx + + # attention + self.self_attn = HunYuanAttention( + model_config, + layer_idx=layer_idx, + ) + + is_experts_valid = ((isinstance(config.num_experts, int) + and config.num_experts > 1) + or (isinstance(config.num_experts, list) + and max(config.num_experts) > 1)) + is_moe_single_node = is_experts_valid and layer_idx >= config.moe_layer_num_skipped # only support one node yet + + if is_moe_single_node: + self.mlp = HunyuanMoE( + model_config, aux_stream_dict[AuxStreamType.MoeChunkingOverlap]) + else: + self.mlp = GatedMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=config.mlp_bias, + dtype=config.torch_dtype, + config=model_config) + + norm_type = getattr(config, 'norm_type', 'rms') + if norm_type == 'hf_rms' or norm_type == 'rms': + self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + self.post_attention_layernorm = RMSNorm( + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + elif norm_type == 'fused' or norm_type == 'torch_nn': + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps) + else: + assert False, "other norm_type are not supported" + + def forward( + self, + position_ids: torch.LongTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + **kwargs, + ) + # Fully Connected + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, attn_metadata) + hidden_states = residual + hidden_states + return hidden_states + + +class HunYuanModel(DecoderModel): + + def __init__(self, model_config: ModelConfig[PretrainedConfig]): + super().__init__(model_config) + config = model_config.pretrained_config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.aux_stream_dict = { + key: torch.cuda.Stream() + for key in [ + AuxStreamType.Attention, AuxStreamType.MoeShared, + AuxStreamType.MoeChunkingOverlap + ] + } + + self.embed_tokens = Embedding( + config.vocab_size, + config.hidden_size, + dtype=config.torch_dtype, + mapping=model_config.mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + gather_output=True, + ) + + self.layers = nn.ModuleList([ + HunYuanDecoderLayer(model_config, layer_idx, self.aux_stream_dict) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + + def forward( + self, + attn_metadata: AttentionMetadata, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> torch.Tensor: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + for layer_idx, decoder_layer in enumerate(self.layers): + kwargs['layer_idx'] = layer_idx + hidden_states = decoder_layer( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +@register_auto_model("HunYuanMoEV1ForCausalLM") +class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, + PretrainedConfig]): + + def __init__(self, model_config: ModelConfig[PretrainedConfig]): + super().__init__(HunYuanModel(model_config), + config=model_config, + hidden_size=model_config.pretrained_config.hidden_size, + vocab_size=model_config.pretrained_config.vocab_size) + self._execution_stats = None + print("---debug model_config: ", model_config) + + def load_weights(self, weights: Dict): + tp_size = self.model_config.mapping.tp_size + head_dim = self.config.hidden_size // self.config.num_attention_heads + + def filter_weights(prefix, weights: Dict): + result = {} + for k, v in weights.items(): + if k.startswith(prefix): + new_k = k[len(prefix) + 1:] + result[new_k] = v + return result + + params_map = { + 'qkv_proj': ['q_proj', 'k_proj', 'v_proj'], + 'gate_up_proj': ['gate_proj', 'up_proj'] + } + for name, module in tqdm(list(self.named_modules()), + desc="Loading weights"): + if len(module._parameters) > 0: + # skip load weights if tie word embeddings is enabled and layer is lm_head + if self.config.tie_word_embeddings and name.startswith( + "lm_head"): + continue + names = name.split('.') + if names[-1] in params_map: + # model.layers.{idx}.mlp.shared_mlp.gate_up_proj or model.layers.{idx}.self_attn.qkv_proj + module_weights = [] + for new_name in params_map[names[-1]]: + fw = filter_weights('.'.join(names[:-1] + [new_name]), + weights) + if new_name in ['k_proj', 'v_proj']: + fw = { + k: + duplicate_kv_weight( + weight=v[:], + num_kv_heads=v[:].shape[0] // head_dim, + tensor_parallel_size=tp_size) + if k in ["weight", "bias"] else v + for k, v in fw.items() + } + module_weights.append(fw) + module.load_weights(weights=module_weights) + else: + name = name.replace('gate', 'gate.wg') + module_weights = filter_weights(name, weights) + if isinstance(module, CutlassFusedMoE) or isinstance( + module, VanillaMoE): + # model.layers.{idx}.mlp.experts + updated_module_weights = {} + for weight_name, weight_value in module_weights.items(): + new_weight_name = weight_name.replace( + "gate_proj", + "w1").replace("up_proj", + "w3").replace("down_proj", "w2") + updated_module_weights[ + new_weight_name] = weight_value + del module_weights + module.load_weights(weights=[updated_module_weights]) + elif hasattr(module, 'load_weights'): + # model.layers.{idx}.self_attn.o_proj or model.layers.{idx}.mlp.shared_mlp.down_proj + # or model.layers.{idx}.mlp.experts.gate + module.load_weights(weights=[module_weights]) + else: + for n, p in module._parameters.items(): + if p is not None: + p.data.copy_(module_weights[n][:]) + + def forward( + self, + attn_metadata: AttentionMetadata, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + return_context_logits: bool = False, + **kwargs, + ) -> torch.Tensor: + output = self.model( + input_ids=input_ids, + attn_metadata=attn_metadata, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + return self.logits_processor.forward( + output, + self.lm_head, + attn_metadata, + return_context_logits, + ) diff --git a/tensorrt_llm/_torch/models/modeling_hyperclovax.py b/tensorrt_llm/_torch/models/modeling_hyperclovax.py index 56d56f2443..a05784b9d8 100644 --- a/tensorrt_llm/_torch/models/modeling_hyperclovax.py +++ b/tensorrt_llm/_torch/models/modeling_hyperclovax.py @@ -15,7 +15,9 @@ from transformers.models.auto import CONFIG_MAPPING from tensorrt_llm.inputs.multimodal import MultimodalParams -from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, +from ...inputs import (ExtraProcessedInputs, InputProcessor, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...logger import logger from ...sampling_params import SamplingParams @@ -961,7 +963,23 @@ class HCXVisionModel: @register_auto_model("HCXVisionForCausalLM") -@register_input_processor(HCXVisionInputProcessor, model_type="hyperclovax_vlm") +@register_input_processor( + HCXVisionInputProcessor, + model_type="hyperclovax_vlm", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={ + "image": + ('<im_end>\n<|im_start|>user (mime) \n' + '{"type": "image/jpeg", "filename": ""}<|im_end|>\n' + '<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n' + '<|im_start|>image/aux\n' + '다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 ' + 'keyword와 bbox 위치입니다.bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 ' + '형태입니다. 참고하여 답변하세요. ' + '{"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}') + }, + placeholder_placement=MultimodalPlaceholderPlacement.AFTER_TEXT, + )) class HCXVisionForCausalLM(PreTrainedModel): def __init__(self, model_config: ModelConfig): diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 6ec6557961..a48d1cdbf6 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -20,7 +20,9 @@ from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import HfLoraLoader from tensorrt_llm.models.convert_utils import split_matrix_tp -from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, +from ...inputs import (ExtraProcessedInputs, InputProcessor, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...sampling_params import SamplingParams from ..attention_backend import AttentionMetadata @@ -74,6 +76,11 @@ class Llama4Attention(Attention): elif get_sm_version() <= 90 and model_config.spec_config is not None: # pre-Blackwell spec-dec kernel does not support attention_chunk_size = None + else: + # Disable chunked attention when max_seq_len is smaller than attention_chunk_size + # TODO: Remove this after all attention kernels in TRTLLM backend support chunked attention + if attention_chunk_size and model_config.max_seq_len and model_config.max_seq_len < attention_chunk_size: + attention_chunk_size = None super().__init__( hidden_size=config.hidden_size, @@ -165,22 +172,16 @@ class Llama4Attention(Attention): q, k, v = self.split_qkv(q, k, v) q = self._attention_scaling(q, position_ids) - out_scale = None - out_scale_sf = None - if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales: - out_scale = self.o_proj.inv_input_scale - if self.o_proj.has_nvfp4 and self.support_nvfp4_output: - out_scale_sf = self.o_proj.input_scale - q, k, v = self.convert_qkv(q, k, v) - attn_output = self.attn.forward(q, + attn_output = self.forward_impl(q, k, v, attn_metadata, - out_scale=out_scale, - out_scale_sf=out_scale_sf, - attention_mask=attention_mask, - mrope_config=mrope_config) + attention_mask, + None, + None, + mrope_config, + attention_sinks=None) attn_output = self.o_proj(attn_output, all_reduce_params=all_reduce_params) @@ -1168,7 +1169,13 @@ class Llama4InputProcessor(InputProcessor): @register_auto_model("Llama4ForConditionalGeneration") -@register_input_processor(Llama4InputProcessor, model_type="llama4") +@register_input_processor( + Llama4InputProcessor, + model_type="llama4", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={"image": "<|image|>"}, + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + )) class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model, Llama4Config]): diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index ac26cb6473..8b52c7cf63 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -14,7 +14,9 @@ from transformers.models.llava_next.modeling_llava_next import ( from tensorrt_llm.inputs.multimodal import MultimodalParams -from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, +from ...inputs import (ExtraProcessedInputs, InputProcessor, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...llmapi.utils import download_hf_model from ...logger import logger @@ -263,7 +265,13 @@ class LlavaNextVisionModel(nn.Module): @register_auto_model("LlavaNextForConditionalGeneration") -@register_input_processor(LlavaNextInputProcessor, model_type="llava_next") +@register_input_processor( + LlavaNextInputProcessor, + model_type="llava_next", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={"image": "<image>"}, + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + )) class LlavaNextModel(PreTrainedModel): config_class = LlavaNextConfig diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index f10ea5368c..2ee0ed4c62 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -29,7 +29,9 @@ from tensorrt_llm._torch.modules.rms_norm import RMSNorm from tensorrt_llm._torch.speculative import SpecMetadata from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.inputs import (ExtraProcessedInputs, InputProcessor, - TextPrompt, register_input_processor) + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, + register_input_processor) from tensorrt_llm.llmapi import SamplingParams from tensorrt_llm.logger import logger @@ -269,8 +271,20 @@ class Mistral3InputProcessor(InputProcessor): @register_auto_model("Mistral3ForConditionalGeneration") -# The below informs the registry which input registry to create for this in `tensorrt_llm/llmapi/llm.py`. -@register_input_processor(Mistral3InputProcessor, model_type="mistral3") +@register_input_processor( + Mistral3InputProcessor, + model_type="mistral3", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={ + "image": "[IMG]", + }, + # NOTE: for mistral3 multimodal models, it does not strictly have to be before the text. + # Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/ + # src/mistral_common/tokens/tokenizers/base.py#L326 + # However, accuracy tests show that the model generates higher quality output when the image + # precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM). + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + )) class Mistral3VLM(PreTrainedModel): """Mistral3VLM implementation for TRTLLM. @@ -475,6 +489,7 @@ class Mistral3PatchMerger(torch.nn.Module): out_features=hidden_size, bias=False, dtype=config.torch_dtype, + mapping=model_config.mapping, ) @torch.inference_mode() @@ -539,6 +554,7 @@ class Mistral3MultiModalProjector(torch.nn.Module): out_features=config.text_config.hidden_size, bias=config.multimodal_projector_bias, dtype=dtype, + mapping=model_config.mapping, ) self.act = ACT2FN[config.projector_hidden_act] self.linear_2 = Linear( @@ -546,6 +562,7 @@ class Mistral3MultiModalProjector(torch.nn.Module): out_features=config.text_config.hidden_size, bias=config.multimodal_projector_bias, dtype=dtype, + mapping=model_config.mapping, ) @torch.inference_mode() diff --git a/tensorrt_llm/_torch/models/modeling_mixtral.py b/tensorrt_llm/_torch/models/modeling_mixtral.py index e16b82020b..21dcc20063 100644 --- a/tensorrt_llm/_torch/models/modeling_mixtral.py +++ b/tensorrt_llm/_torch/models/modeling_mixtral.py @@ -15,6 +15,7 @@ from ..modules.embedding import Embedding from ..modules.fused_moe import RenormalizeMoeRoutingMethod, create_moe from ..modules.linear import Linear from ..modules.rms_norm import RMSNorm +from ..utils import AuxStreamType from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, register_auto_model) @@ -49,7 +50,7 @@ class MixtralMoE(nn.Module): routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k), hidden_size=self.hidden_dim, intermediate_size=self.ffn_dim, - aux_stream=aux_stream, + aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream}, dtype=config.torch_dtype, reduce_results=reduce_results, model_config=model_config, @@ -160,6 +161,8 @@ class MixtralModel(DecoderModel): config.vocab_size, config.hidden_size, dtype=config.torch_dtype, + enable_torch_compile_for_embedding=model_config. + enable_torch_compile_for_embedding, ) self.layers = nn.ModuleList([ diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index d051adf12b..41f870f890 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -63,8 +63,16 @@ class MLPLayer(MLP): layer_idx: int, ): config = model_config.pretrained_config + if isinstance(config.intermediate_size, list): + if len(config.intermediate_size) == 1: + intermediate_size = config.intermediate_size[0] + else: + intermediate_size = config.intermediate_size[layer_idx] + else: + intermediate_size = config.intermediate_size + super().__init__(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, + intermediate_size=intermediate_size, bias=False, activation=relu2, dtype=config.torch_dtype, diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py index b5ad4f4520..ee0263eb5e 100644 --- a/tensorrt_llm/_torch/models/modeling_phi4mm.py +++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py @@ -1,19 +1,30 @@ # Plan for phi4-mm model support. # (done) step 1: support legacy inference pipeline for phi4-mm model. -# (todo) step 2: refactor the inference pipeline to use AGGREGATE mode (https://github.com/NVIDIA/TensorRT-LLM/pull/5522). +# (done) step 2: refactor the inference pipeline to use AGGREGATE mode (https://github.com/NVIDIA/TensorRT-LLM/pull/5522). +# (todo) step 3: optimization +# * use TRTLLM-attention to replace original pytorch attention in vision/audio encoders. +# * use data parallel to accelerate inference. import copy +import importlib +import os +import sys +from pathlib import Path from typing import List, Optional, Tuple import torch import transformers from PIL import Image +from tensorrt_llm.inputs.multimodal import MultimodalParams + from ...executor.request import LoRARequest -from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, +from ...inputs import (ExtraProcessedInputs, InputProcessor, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...logger import logger -from ...lora_manager import LoraConfig +from ...lora_helper import LoraConfig from ...sampling_params import SamplingParams from ..attention_backend import AttentionMetadata from ..model_config import ModelConfig @@ -21,16 +32,361 @@ from .modeling_auto import AutoModelForCausalLM from .modeling_multimodal_utils import fuse_input_embeds from .modeling_utils import register_auto_model -# Special tokens -_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>' -_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' +# Special token ids from the original Phi-4-multimodal-instruct implementation +_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>' from HF `modeling_phi4mm.py` +_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' from HF `modeling_phi4mm.py` +_PAD_TOKEN_ID = 199999 # '<|endoftext|>' from HF `special_tokens_map.json` +_COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE = [-9999, + -1] # from HF `modeling_phi4mm.py` +_COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE = [float('-inf'), -10000 + ] # from HF `modeling_phi4mm.py` + +# Below classes will be loaded from HuggingFace codes, rather than using transformers version, +# since transformers version is not compatible with checkpoints and configs from `microsoft/Phi-4-multimodal-instruct`. +Phi4MMAudioEmbedding = None +Phi4MMImageEmbedding = None +Phi4MMConfig = None -# Create a PreTrainedModel class for transformers=4.53.1 upgrade. -# Core idea is to provide `prepare_inputs_for_generation` method from `GenerationMixin`. -class NewPreTrainedModel(transformers.modeling_utils.PreTrainedModel, - transformers.generation.GenerationMixin): - pass +# Make this a runtime lookup rather than a module-wide constant for easier unit testing. +def _is_disagg() -> bool: + return os.getenv("TLLM_MULTIMODAL_DISAGGREGATED", "0") == "1" + + +# Load the Phi4MM classes from HuggingFace Phi-4-multimodal-instruct repo. +# Remove this function by using the transformers version of Phi4Multimodal when weights/configs are converted to transformers format. +def _load_phi4mm_classes(local_path): + """Load Phi4MM classes from the specified local path.""" + global Phi4MMAudioEmbedding, Phi4MMImageEmbedding, Phi4MMConfig + if Phi4MMAudioEmbedding is not None and Phi4MMImageEmbedding is not None and Phi4MMConfig is not None: + return + + # Add parent folder to sys.path to enable relative import. + original_sys_path = sys.path.copy() + package_folder = Path(local_path) + parent_folder = str(package_folder.parent) + if parent_folder not in sys.path: + sys.path.insert(0, parent_folder) + + try: + # Import Phi4MMConfig from configuration_phi4mm.py. + config_path = os.path.join(local_path, 'configuration_phi4mm.py') + if not os.path.exists(config_path): + raise FileNotFoundError( + f"configuration_phi4mm.py not found at {local_path}.") + spec = importlib.util.spec_from_file_location("hf_config", config_path) + hf_config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(hf_config) + Phi4MMConfig = hf_config.Phi4MMConfig + + # Import Phi4MMAudioEmbedding and Phi4MMImageEmbedding from modeling_phi4mm.py. + modeling_phi4mm_path = os.path.join(local_path, 'modeling_phi4mm.py') + if not os.path.exists(modeling_phi4mm_path): + raise FileNotFoundError( + f"modeling_phi4mm.py not found at {local_path}.") + # `Phi-4-multimodal-instruct` as the package name to avoid relative import errors. + # `hf_modeling_phi4mm` as the module name to avoid name conflicts. + spec = importlib.util.spec_from_file_location( + "Phi-4-multimodal-instruct.hf_modeling_phi4mm", + modeling_phi4mm_path) + hf_modeling_phi4mm = importlib.util.module_from_spec(spec) + spec.loader.exec_module(hf_modeling_phi4mm) + Phi4MMAudioEmbedding = hf_modeling_phi4mm.Phi4MMAudioEmbedding + Phi4MMImageEmbedding = hf_modeling_phi4mm.Phi4MMImageEmbedding + finally: + sys.path = original_sys_path + + +class HFPhi4MultimodalEncoder(transformers.PreTrainedModel, + transformers.generation.GenerationMixin): + + # Copy and modify from https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py::Phi4MMImageAudioEmbedding + # Note: the HF implementation here will cause duplicated encoders on all GPUs for TP>1 scenario. + # TODO: use TRTLLM-attention to replace original pytorch Flash_attn_2 in HFPhi4MultimodalEncoder. + config_class = Phi4MMConfig + base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config: transformers.PretrainedConfig, **kwargs): + super().__init__(config, **kwargs) + self.padding_idx = config.pad_token_id + + self.embed_tokens = torch.nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx) + + self._attn_implementation = config._attn_implementation + + self.vocab_size = config.vocab_size + + embedding_config = { + 'embedding_cls': config.embd_layer['embedding_cls'], + **config.embd_layer + } + # The default values are from HuggingFace Phi-4-multimodal-instruct codes. + self.image_input_id = embedding_config.get('image_input_id', -1) + self.audio_input_id = embedding_config.get('audio_input_id', -10000) + if self.image_input_id == self.audio_input_id: + raise ValueError( + 'image_input_id and audio_input_id should be different') + + self.image_embd_layer_kwargs = embedding_config['image_embd_layer'] + self.image_embed = Phi4MMImageEmbedding(config, + **self.image_embd_layer_kwargs) + + self.audio_embd_layer_kwargs = embedding_config['audio_embd_layer'] + self.audio_embed = Phi4MMAudioEmbedding(config, + **self.audio_embd_layer_kwargs) + + def _replace_special_token_ids(self, + input_ids: torch.Tensor) -> torch.Tensor: + # Inplace-replacement for special token ids. + torch.where( + (input_ids >= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[0]) + & (input_ids <= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[1]), + torch.tensor(_IMAGE_SPECIAL_TOKEN_ID), + input_ids, + out=input_ids, + ) + torch.where( + (input_ids >= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[0]) + & (input_ids <= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[1]), + torch.tensor(_AUDIO_SPECIAL_TOKEN_ID), + input_ids, + out=input_ids, + ) + return input_ids + + def _batch_infer_image_embeds( + self, batched_input_ids: torch.Tensor, + multimodal_params: List[MultimodalParams]) -> torch.Tensor: + # Batch image inputs and attention mask with padding along dim=1 (patch num). + input_image_embeds_list, input_image_attn_mask_list, input_image_sizes_list = [], [], [] + for mm_param in multimodal_params: + mm_data = mm_param.multimodal_data + input_image_embeds = mm_data["input_image_embeds"] + if input_image_embeds is not None and input_image_embeds.numel( + ) > 0: + input_image_embeds_list.append(input_image_embeds) + input_image_attn_mask_list.append( + mm_data["image_attention_mask"]) + input_image_sizes_list.append(mm_data["image_sizes"]) + batched_image_hidden_states = None + if len(input_image_embeds_list) > 0: + # Padding image embeds/attn_masks along dim=1 (patch dimension). + b_list = [x.shape[0] for x in input_image_embeds_list] + p_list = [x.shape[1] for x in input_image_embeds_list] + c_i, h_i, w_i = input_image_embeds_list[0].shape[2:5] + h_i_attn, w_i_attn = input_image_attn_mask_list[0].shape[2:4] + total_b = sum(b_list) + max_p = max(p_list) + batched_image_embeds = torch.zeros( + (total_b, max_p, c_i, h_i, w_i), + dtype=input_image_embeds_list[0].dtype, + device=input_image_embeds_list[0].device) + batched_image_attn_mask = torch.zeros( + (total_b, max_p, h_i_attn, w_i_attn), + dtype=input_image_embeds_list[0].dtype, + device=input_image_embeds_list[0].device) + b_offset = 0 + for i, tensor in enumerate(input_image_embeds_list): + b, p = tensor.shape[:2] + batched_image_embeds[b_offset:b_offset + b, :p] = tensor + if input_image_attn_mask_list[i] is not None: + batched_image_attn_mask[ + b_offset:b_offset + + b, :p] = input_image_attn_mask_list[i] + else: + batched_image_attn_mask[b_offset:b_offset + b, :p] = 1 + b_offset += b + batched_image_sizes = torch.cat(input_image_sizes_list, dim=0) + # Forward image encoder with batched image embeds. + batched_image_hidden_states = self.image_embed( + input_ids=batched_input_ids, + input_embeds=batched_image_embeds, + image_sizes=batched_image_sizes, + image_attention_mask=batched_image_attn_mask, + wte=self.embed_tokens, + ) + return batched_image_hidden_states + + def _batch_infer_audio_embeds( + self, batched_input_ids: torch.Tensor, + multimodal_params: List[MultimodalParams]) -> torch.Tensor: + # Batch audio inputs and attention mask with padding along dim=1 (patch num). + input_audio_embeds_list, input_audio_attn_mask_list, input_audio_sizes_list = [], [], [] + for mm_param in multimodal_params: + mm_data = mm_param.multimodal_data + input_audio_embeds = mm_data["input_audio_embeds"] + if input_audio_embeds is not None and input_audio_embeds.numel( + ) > 0: + input_audio_embeds_list.append(input_audio_embeds) + input_audio_attn_mask_list.append( + mm_data["audio_attention_mask"]) + input_audio_sizes_list.append(mm_data["audio_embed_sizes"]) + batched_audio_hidden_states = None + if len(input_audio_embeds_list) > 0: + b_list = [x.shape[0] for x in input_audio_embeds_list] + p_list = [x.shape[1] for x in input_audio_embeds_list] + d_a = input_audio_embeds_list[0].shape[2] + total_b = sum(b_list) + max_p = max(p_list) + batched_audio_embeds = torch.zeros( + (total_b, max_p, d_a), + dtype=input_audio_embeds_list[0].dtype, + device=input_audio_embeds_list[0].device) + batched_audio_attn_mask = torch.zeros( + (total_b, max_p), + dtype=input_audio_embeds_list[0].dtype, + device=input_audio_embeds_list[0].device) + b_offset = 0 + for i, tensor in enumerate(input_audio_embeds_list): + b, p = tensor.shape[:2] + batched_audio_embeds[b_offset:b_offset + b, :p] = tensor + if input_audio_attn_mask_list[i] is not None: + batched_audio_attn_mask[ + b_offset:b_offset + + b, :p] = input_audio_attn_mask_list[i] + else: + batched_audio_attn_mask[b_offset:b_offset + b, :p] = 1 + b_offset += b + batched_audio_sizes = torch.cat(input_audio_sizes_list, dim=0) + # Forward audio encoder with batched audio embeds. + batched_audio_hidden_states = self.audio_embed( + input_ids=batched_input_ids, + input_embeds=batched_audio_embeds, + audio_embed_sizes=batched_audio_sizes, + audio_attention_mask=batched_audio_attn_mask, + wte=self.embed_tokens, + ) + return batched_audio_hidden_states + + def _encoding_per_request( + self, multimodal_params: List[MultimodalParams], + mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]: + # Loop implementation. + mm_embeddings = [] + for i in range(len(multimodal_params)): + input_ids = multimodal_params[i].multimodal_data["input_ids"] + input_image_embeds = multimodal_params[i].multimodal_data[ + "input_image_embeds"] + input_audio_embeds = multimodal_params[i].multimodal_data[ + "input_audio_embeds"] + image_sizes = multimodal_params[i].multimodal_data["image_sizes"] + image_attention_mask = multimodal_params[i].multimodal_data[ + "image_attention_mask"] + audio_embed_sizes = multimodal_params[i].multimodal_data[ + "audio_embed_sizes"] + audio_attention_mask = multimodal_params[i].multimodal_data[ + "audio_attention_mask"] + audio_projection_mode = multimodal_params[i].multimodal_data[ + "audio_projection_mode"] + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids = self._replace_special_token_ids(input_ids) + image_position_mask = input_ids == _IMAGE_SPECIAL_TOKEN_ID + non_image_position_mask = ~image_position_mask + + image_hidden_states = None + if input_image_embeds is not None: + image_hidden_states = self.image_embed( + input_ids=input_ids, + input_embeds=input_image_embeds, + image_sizes=image_sizes, + wte=self.embed_tokens, + image_attention_mask=image_attention_mask, + ) + audio_hidden_states = None + if input_audio_embeds is not None: + audio_hidden_states = self.audio_embed( + input_ids=input_ids, + input_embeds=input_audio_embeds, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + wte=self.embed_tokens, + audio_projection_mode=audio_projection_mode, + ) + + if input_image_embeds is not None and input_audio_embeds is not None: + dtype = image_hidden_states.dtype + hidden_states = image_hidden_states * image_position_mask.to( + dtype).unsqueeze( + -1) + audio_hidden_states * non_image_position_mask.to( + dtype).unsqueeze(-1) + elif input_image_embeds is not None: + hidden_states = image_hidden_states + elif input_audio_embeds is not None: + hidden_states = audio_hidden_states + else: + hidden_states = self.embed_tokens(input_ids) + + # Postprocessing to get multimodal-only embeddings. + mm_token_mask = torch.isin(input_ids, mm_token_ids) + hidden_states = hidden_states[mm_token_mask] + + mm_embeddings.append(hidden_states) + return mm_embeddings + + def _encoding_batch_request( + self, multimodal_params: List[MultimodalParams], + mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]: + # Batch input_ids. + input_ids_list = [ + multimodal_params[i].multimodal_data["input_ids"] + for i in range(len(multimodal_params)) + ] + max_input_ids_len = max( + [input_ids.shape[1] for input_ids in input_ids_list]) + batched_input_ids = torch.full( + (len(multimodal_params), max_input_ids_len), + _PAD_TOKEN_ID, + device=input_ids_list[0].device) + for i, input_ids in enumerate(input_ids_list): + batched_input_ids[i, :input_ids.shape[1]] = input_ids + batched_input_ids = batched_input_ids.view(-1, max_input_ids_len) + batched_input_ids = self._replace_special_token_ids(batched_input_ids) + image_position_mask = batched_input_ids == _IMAGE_SPECIAL_TOKEN_ID + non_image_position_mask = ~image_position_mask + + # Batch inference for image and audio embeds. + batched_image_hidden_states = self._batch_infer_image_embeds( + batched_input_ids, multimodal_params) + batched_audio_hidden_states = self._batch_infer_audio_embeds( + batched_input_ids, multimodal_params) + + # Combine different modalities into one. + if batched_image_hidden_states is not None and batched_audio_hidden_states is not None: + batched_hidden_states = batched_image_hidden_states * image_position_mask.unsqueeze( + -1 + ) + batched_audio_hidden_states * non_image_position_mask.unsqueeze( + -1) + elif batched_image_hidden_states is not None: + batched_hidden_states = batched_image_hidden_states + elif batched_audio_hidden_states is not None: + batched_hidden_states = batched_audio_hidden_states + else: + batched_hidden_states = self.embed_tokens(batched_input_ids) + + # Postprocessing to get multimodal-only embeddings. + mm_token_mask = torch.isin(batched_input_ids, mm_token_ids) + batched_hidden_states = batched_hidden_states[mm_token_mask] + batched_hidden_states = [batched_hidden_states] + return batched_hidden_states + + @torch.inference_mode() + def forward(self, multimodal_params: List[MultimodalParams], + mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]: + if os.getenv("PHI4_MM_PER_REQUEST_INFER", "0") == "1": + # Reference code path to check correctness of batch inference and further dev. + # (TODO) Remove this path after accuracy bench and data parallelism are supported. + return self._encoding_per_request(multimodal_params, mm_token_ids) + else: + # Batch inference as default path. + return self._encoding_batch_request(multimodal_params, mm_token_ids) class Phi4MMInputProcessor(InputProcessor): @@ -40,10 +396,11 @@ class Phi4MMInputProcessor(InputProcessor): model_config: transformers.PretrainedConfig, tokenizer: transformers.AutoTokenizer, trust_remote_code: bool = True): - assert trust_remote_code, "trust_remote_code must be True for Phi4MM" + if not trust_remote_code: + raise ValueError("trust_remote_code must be True for Phi4MM") self.model_config = model_config - self.device = 'cuda' + self.device = 'cpu' self.tokenizer = tokenizer self.use_fast = True @@ -58,37 +415,18 @@ class Phi4MMInputProcessor(InputProcessor): trust_remote_code=trust_remote_code, use_fast=self.use_fast) - # Build pure-pytorch model architecture for multimodal encoder. - # Model weights are also loaded here. - OldPreTrainedModel = transformers.modeling_utils.PreTrainedModel - transformers.modeling_utils.PreTrainedModel = NewPreTrainedModel - # TODO: Make separate Phi4VisionEncoder and Phi4AudioEncoder, and move them to LLM-side. - ref_phi4mm_model = transformers.AutoModelForCausalLM.from_pretrained( - model_path, - trust_remote_code=True, - # Flash_attn_2 only supports bf16 or fp16 and set in HF config. - torch_dtype='auto', - _attn_implementation='flash_attention_2', - ).eval() - transformers.modeling_utils.PreTrainedModel = OldPreTrainedModel - self.phi4mm_modal_encoder = ref_phi4mm_model.model.embed_tokens_extend.to( - self.device) - # Required by Phi4MMImageAudioEmbedding. - # See link: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L701 - self.phi4mm_wte = ref_phi4mm_model.model.embed_tokens.to(self.device) - @torch.inference_mode() def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: - text_prompt, mm_data, mm_processor_kwargs = inputs.get("prompt"), \ - inputs.get("multi_modal_data", {}), inputs.get("mm_processor_kwargs", {}) + text_prompt, mm_data = inputs.get("prompt"), inputs.get( + "multi_modal_data", {}) images = mm_data.get("image", None) audios = mm_data.get("audio", None) if images is not None: if isinstance(images[0], torch.Tensor): - # Convert normalized tensors (0-1) to PIL images (0-255). + # HF Phi4MM can only support PIL images. Convert normalized tensors (0-1) to PIL images (0-255). images = [ Image.fromarray((image.permute(1, 2, 0) * 255).to( torch.uint8).cpu().numpy()) for image in images @@ -109,43 +447,41 @@ class Phi4MMInputProcessor(InputProcessor): else: audio_projection_mode = 'speech' - # Processing with Phi4MMImageAudioEmbedding. - mm_features = self.phi4mm_modal_encoder( - input_ids=inputs['input_ids'], - input_embeds=None, - input_image_embeds=inputs['input_image_embeds'], - input_audio_embeds=inputs['input_audio_embeds'], - image_sizes=inputs['image_sizes'], - image_attention_mask=inputs['image_attention_mask'], - audio_embed_sizes=inputs['audio_embed_sizes'], - audio_attention_mask=inputs['audio_attention_mask'], - audio_projection_mode=audio_projection_mode, - wte=self.phi4mm_wte, - ) - - # Postprocessing to get multimodal-only embeddings. - image_token_mask = inputs['input_ids'] == _IMAGE_SPECIAL_TOKEN_ID - audio_token_mask = inputs['input_ids'] == _AUDIO_SPECIAL_TOKEN_ID - mm_token_mask = image_token_mask | audio_token_mask - mm_features = mm_features[mm_token_mask] - + # Will package inputs for language model forward in AGGREGATE mode. multimodal_data = {} - multimodal_data["multimodal_embedding"] = mm_features - + multimodal_data['input_ids'] = inputs['input_ids'] + multimodal_data['input_image_embeds'] = inputs['input_image_embeds'] + multimodal_data['image_sizes'] = inputs['image_sizes'] + multimodal_data['image_attention_mask'] = inputs['image_attention_mask'] + multimodal_data['input_audio_embeds'] = inputs['input_audio_embeds'] + multimodal_data['audio_embed_sizes'] = inputs['audio_embed_sizes'] + multimodal_data['audio_attention_mask'] = inputs['audio_attention_mask'] + multimodal_data['audio_projection_mode'] = audio_projection_mode return inputs['input_ids'][0].to(torch.int32).tolist(), { "multimodal_data": multimodal_data, } @register_auto_model("Phi4MMForCausalLM") -@register_input_processor(Phi4MMInputProcessor, model_type="phi4mm") +@register_input_processor( + Phi4MMInputProcessor, + model_type="phi4mm", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={ + "image": "<|image_{0}|>", + "audio": "<|audio_{0}|>", + }, + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + placeholders_separator="", + )) class Phi4MMForCausalLM(transformers.PreTrainedModel): _supports_flash_attn_2 = True - MM_TOKEN_IDS = torch.tensor( - [_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID]) def __init__(self, model_config: ModelConfig): + if _is_disagg(): + raise ValueError( + "Phi4MM does not support disaggregated inference yet.") config = model_config.pretrained_config super().__init__(config) @@ -154,6 +490,15 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): if hasattr(self, "llm"): return + if not _is_disagg(): + _load_phi4mm_classes(config._name_or_path) + + # Setup HFPhi4MultimodalEncoder in AGGREGATE mode. + self.hf_phi4mm_model = HFPhi4MultimodalEncoder(config).eval() + self.hf_phi4mm_model.to(config.torch_dtype) + # Required by HFPhi4MultimodalEncoder. + self.phi4mm_wte = self.hf_phi4mm_model.embed_tokens + # We use Phi3ForCausalLM as the language model. llm_model_config = copy.deepcopy(model_config) llm_model_config.pretrained_config.architectures = ["Phi3ForCausalLM"] @@ -167,6 +512,18 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): self.is_loaded = True def load_weights(self, weights): + # Load weights into HFPhi4MultimodalEncoder. + if not _is_disagg(): + filtered_weights = {} + for k, v in weights.items(): + if k.startswith("model.embed_tokens."): + new_k = k.replace("model.embed_tokens.", "embed_tokens.") + filtered_weights[new_k] = v + elif k.startswith("model.embed_tokens_extend."): + new_k = k.replace("model.embed_tokens_extend.", "") + filtered_weights[new_k] = v + self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True) + # Filter out non-language model weights. weights = { k: v @@ -185,9 +542,13 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): else: updated_weights[k] = weights[k] weights = updated_weights - self.llm.load_weights(weights) + # Move mm_token_ids to the correct device. + self.mm_token_ids = torch.tensor( + [_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID], + device=self.device) + def infer_max_seq_len(self) -> int: return self.llm.infer_max_seq_len() @@ -215,17 +576,24 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): ) multimodal_params = kwargs.get("multimodal_params", []) - mm_embeds = [] + mm_embedding = [] if len(multimodal_params) > 0: - mm_embeds = [ - multimodal_param.multimodal_data["multimodal_embedding"] - for multimodal_param in multimodal_params - ] + if not _is_disagg(): + # Forward the multimodal data to HFPhi4MultimodalEncoder in AGGREGATE mode. + mm_embedding = self.hf_phi4mm_model(multimodal_params, + self.mm_token_ids) + else: + # Directly fetch the multimodal embedding for DISAGG mode. + # This path is not functional now. `multimodal_params` will be prepared in PyExecutor. + mm_embedding = [ + multimodal_param.multimodal_data["multimodal_embedding"] + for multimodal_param in multimodal_params + ] input_ids, input_embeds = fuse_input_embeds( self.llm.model.embed_tokens, input_ids, - mm_embeds, - mm_token_ids=self.MM_TOKEN_IDS, + mm_embedding, + mm_token_ids=self.mm_token_ids, ) output_prob = self.llm.forward( diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 03f15c37b4..67e329e326 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -12,7 +12,9 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams from ..._utils import nvtx_range_debug from ...functional import RopeEmbeddingUtils, RotaryScalingType -from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, +from ...inputs import (ExtraProcessedInputs, InputProcessor, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...logger import logger from ...sampling_params import SamplingParams @@ -645,7 +647,16 @@ class Qwen2VLModelBase(PreTrainedModel): @register_auto_model("Qwen2VLForConditionalGeneration") -@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_vl") +@register_input_processor( + Qwen2VLInputProcessorBase, + model_type="qwen2_vl", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={ + "image": "<|vision_start|><|image_pad|><|vision_end|>", + "video": "<|vision_start|><|video_pad|><|vision_end|>" + }, + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + )) class Qwen2VLModel(Qwen2VLModelBase): def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, @@ -657,7 +668,14 @@ class Qwen2VLModel(Qwen2VLModelBase): @register_auto_model("Qwen2_5_VLForConditionalGeneration") -@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_5_vl") +@register_input_processor( + Qwen2VLInputProcessorBase, + model_type="qwen2_5_vl", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={ + "image": "<|vision_start|><|image_pad|><|vision_end|>", + "video": "<|vision_start|><|video_pad|><|vision_end|>" + })) class Qwen2_5_VLModel(Qwen2VLModelBase): def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py index 8635e510f4..48a73e85f1 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3.py @@ -1,8 +1,9 @@ +import math from typing import Optional, Tuple import torch from torch import nn -from transformers import Qwen3Config +from transformers import PretrainedConfig, Qwen3Config from tensorrt_llm.functional import PositionEmbeddingType @@ -21,6 +22,111 @@ from .modeling_speculative import SpecDecOneEngineForCausalLM from .modeling_utils import DecoderModel, register_auto_model +# Move out from this class +def compute_yarn_parameters( + config: PretrainedConfig, ) -> tuple[float, float, float, float]: + """ + Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1 + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://huggingface.co/papers/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + Returns: + factor: float, the scaling factor for the RoPE embeddings + low: float, the lower bound of the dimension range + high: float, the upper bound of the dimension range + attention_factor: float, the post-processing scaling factor applied to the computed cos/sin + """ + + # The config does not contain rope_scaling, which means the model is not using yarn + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is None: + return 1.0, 0, 0, 1.0 + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr( + config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", + config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + factor = getattr(rope_scaling, "factor", 1.0) + attention_factor = rope_scaling.get("attention_factor") + mscale = rope_scaling.get("mscale") + mscale_all_dim = rope_scaling.get("mscale_all_dim") + + if "original_max_position_embeddings" in rope_scaling: + original_max_position_embeddings = rope_scaling[ + "original_max_position_embeddings"] + factor = config.max_position_embeddings / original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if mscale and mscale_all_dim: + attention_factor = float( + get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) + else: + attention_factor = get_mscale(factor) + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = rope_scaling.get("beta_fast") or 32 + beta_slow = rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * + math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, + max_position_embeddings, truncate): + """Find dimension range bounds based on rotations""" + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + truncate = rope_scaling.get("truncate", True) + low, high = find_correction_range(beta_fast, beta_slow, dim, base, + original_max_position_embeddings, + truncate) + + # These parts are implemented in the fusedQKNormRopeKernel.cu + # # def linear_ramp_factor(min, max, dim): + # # if min == max: + # # max += 0.001 # Prevent singularity + + # # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + # # ramp_func = torch.clamp(linear_func, 0, 1) + # # return ramp_func + + # # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # # to expand the possible context length. In other words, interpolation = apply scaling factor. + # # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) + # # inv_freq_extrapolation = 1.0 / pos_freqs + # # inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + # # # Get n-dimensional rotational scaling corrected for extrapolation + # # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) + # # inv_freq = ( + # # inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + # # + inv_freq_extrapolation * inv_freq_extrapolation_factor + # # ) + # # return inv_freq, attention_factor + return factor, low, high, attention_factor + + class Qwen3Attention(Attention): def __init__( @@ -30,11 +136,18 @@ class Qwen3Attention(Attention): fuse_qk_norm_rope: bool = True, ): config = model_config.pretrained_config + self.pretrained_config = config if getattr(config, "rope_scaling", None) is not None: + if "type" in config.rope_scaling: + pos_type = config.rope_scaling["type"] + elif "rope_type" in config.rope_scaling: + pos_type = config.rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have type or rope_type field") pos_embd_params = PositionalEmbeddingParams( - type=PositionEmbeddingType.from_string( - config.rope_scaling["type"]), + type=PositionEmbeddingType.from_string(pos_type), rope=RopeParams.from_config(config), ) else: @@ -92,12 +205,15 @@ class Qwen3Attention(Attention): return q, k def apply_qk_norm_rope(self, qkv, position_ids): + factor, low, high, attention_factor = compute_yarn_parameters( + self.pretrained_config) torch.ops.trtllm.fused_qk_norm_rope( qkv, self.num_heads, self.num_key_value_heads, self.num_key_value_heads, self.head_dim, self.q_norm.variance_epsilon, self.q_norm.weight, - self.k_norm.weight, self.pos_embd_params.rope.theta, - self.pos_embd_params.is_neox, position_ids.view(-1)) + self.k_norm.weight, + self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox, + position_ids.view(-1), factor, low, high, attention_factor) return qkv, None, None def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index 2d447dd527..bd2ccfae0c 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -22,6 +22,7 @@ from ..modules.fused_moe import (BaseMoeRoutingMethod, from ..modules.linear import TensorParallelMode from ..modules.rms_norm import RMSNorm from ..speculative import SpecMetadata +from ..utils import AuxStreamType from .modeling_qwen3 import Qwen3Attention from .modeling_speculative import SpecDecOneEngineForCausalLM from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model @@ -107,7 +108,7 @@ class Qwen3MoE(nn.Module): routing_method=self.gate.routing_method, hidden_size=self.hidden_dim, intermediate_size=self.moe_intermediate_size, - aux_stream=aux_stream, + aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream}, dtype=config.torch_dtype, reduce_results=False, model_config=model_config, @@ -214,7 +215,9 @@ class Qwen3MoEDecoderLayer(DecoderLayer): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MOE_FUSION = False # Self Attention hidden_states = self.self_attn( position_ids=position_ids, @@ -257,9 +260,6 @@ class Qwen3MoEDecoderLayer(DecoderLayer): if self.fusion_config.POST_MOE_FUSION: if do_finalize: - if spec_metadata: - spec_metadata.maybe_capture_hidden_states( - self.layer_idx, hidden_states, residual) hidden_states, residual = self.allreduce( hidden_states, all_reduce_params=AllReduceParams( @@ -289,12 +289,8 @@ class Qwen3MoEDecoderLayer(DecoderLayer): hidden_states, residual = self.moe_allreduce( fc2_output, all_reduce_params=moe_all_reduce_params) - if spec_metadata: - spec_metadata.maybe_capture_hidden_states( - self.layer_idx, hidden_states, residual) - else: - if spec_metadata: + if spec_metadata and spec_metadata.is_layer_capture(self.layer_idx): spec_metadata.maybe_capture_hidden_states( self.layer_idx, hidden_states, residual) if self.next_layer_layernorm is not None: diff --git a/tensorrt_llm/_torch/models/modeling_qwen_moe.py b/tensorrt_llm/_torch/models/modeling_qwen_moe.py index 2bbf9b80d5..4fa3883246 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen_moe.py @@ -17,6 +17,7 @@ from ..modules.fused_moe import DefaultMoeRoutingMethod, create_moe from ..modules.gated_mlp import GatedMLP from ..modules.linear import Linear, TensorParallelMode from ..modules.rms_norm import RMSNorm +from ..utils import AuxStreamType from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, register_auto_model) @@ -53,7 +54,7 @@ class QwenMoE(nn.Module): routing_method=DefaultMoeRoutingMethod(top_k=self.top_k), hidden_size=self.hidden_dim, intermediate_size=self.moe_intermediate_size, - aux_stream=aux_stream, + aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream}, dtype=config.torch_dtype, reduce_results=reduce_results, model_config=model_config, diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index e8a5774211..f82c3b4de0 100644 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Generic, Optional, Tuple import torch from torch import nn -from transformers import LlamaConfig +from transformers import LlamaConfig, PretrainedConfig from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper @@ -320,14 +320,45 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]): return hidden_states -def get_draft_model(model_config, draft_config): +class MTPForCausalLM(nn.Module): + + def __init__( + self, + model_config: ModelConfig[PretrainedConfig], + start_layer_idx: int = 0, + lm_head: nn.Module = None, + model: nn.Module = None, + ): + super().__init__() + # Import here to avoid circular import + from .modeling_deepseekv3 import DeepseekV3MTP + + spec_dec_mode = model_config.spec_config.spec_dec_mode + assert spec_dec_mode.is_mtp() + mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle( + ) else model_config.spec_config.num_nextn_predict_layers + + self.mtp_layers = nn.ModuleList([ + DeepseekV3MTP(model_config, layer_idx + start_layer_idx, + model.aux_stream_dict) + for layer_idx in range(mtp_num_layers) + ]) + self.lm_head = lm_head + self.embed_tokens = model.embed_tokens + + +def get_draft_model(model_config, draft_config, lm_head, model): assert getattr(model_config, 'spec_config', None) != None spec_dec_mode = model_config.spec_config.spec_dec_mode if spec_dec_mode.is_eagle3_one_model(): return Eagle3ForCausalLM( draft_config, model_config.pretrained_config.num_hidden_layers) + elif spec_dec_mode.is_mtp(): + return MTPForCausalLM(model_config, + model_config.pretrained_config.num_hidden_layers, + lm_head, model) else: - raise NotImplemented( + raise NotImplementedError( f"get_draft_model does not support speculative decoding mode {spec_dec_mode}." ) @@ -341,27 +372,34 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig], hidden_size=model_config.pretrained_config.hidden_size, vocab_size=model_config.pretrained_config.vocab_size) self.draft_model = None - if getattr( - model_config, 'spec_config', None - ) and model_config.spec_config.spec_dec_mode.use_one_engine(): - draft_config = ModelConfig.from_pretrained( - model_config.spec_config.speculative_model_dir, - trust_remote_code=True, - attn_backend=model_config.attn_backend, - moe_backend=model_config.moe_backend, - mapping=model_config.mapping, - spec_config=model_config.spec_config, - max_num_tokens=model_config.max_num_tokens, - moe_max_num_tokens=model_config.moe_max_num_tokens) - - draft_config.quant_config.kv_cache_quant_algo = \ + spec_config = getattr(model_config, 'spec_config', None) + if spec_config and spec_config.spec_dec_mode.use_one_engine(): + draft_config = None + if spec_config.spec_dec_mode.is_eagle3_one_model(): + draft_config = ModelConfig.from_pretrained( + model_config.spec_config.speculative_model_dir, + trust_remote_code=True, + attn_backend=model_config.attn_backend, + moe_backend=model_config.moe_backend, + mapping=model_config.mapping, + spec_config=model_config.spec_config, + max_num_tokens=model_config.max_num_tokens, + moe_max_num_tokens=model_config.moe_max_num_tokens) + draft_config.quant_config.kv_cache_quant_algo = \ model_config.quant_config.kv_cache_quant_algo - self.draft_model = get_draft_model(model_config, draft_config) + self.draft_model = get_draft_model(model_config, draft_config, + self.lm_head, self.model) self.spec_worker = get_spec_worker(model_config.spec_config, model_config, model_config.mapping) + if draft_config is not None: + for key, value in draft_config.extra_attrs.items(): + assert key in ('attn_layers', 'mla_layers') + assert key in model_config.extra_attrs + model_config.extra_attrs[key].update(value) + def forward( self, attn_metadata: AttentionMetadata, diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index 4252ecd4f5..238ac97ffe 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -421,8 +421,7 @@ class DecoderModelForCausalLM(nn.Module, self.model.__pp_init__() - def __post_init__(self): - # 1. mixed precision + def apply_layerwise_quant_config(self): quant_config_dict = self.model_config.quant_config_dict if quant_config_dict is not None: for name, module in self.named_modules(): @@ -466,11 +465,14 @@ class DecoderModelForCausalLM(nn.Module, module.quant_config = q break - # 2. skip quant for modules in QuantConfig.exclude_modules. - # kv_cache_quant_algo takes precedence over exclude_modules. - # kv_cache_quant_algo, if not None, is set for non-Attention - # modules too, which is the same practice as when there's no - # exclude_modules. + def apply_quant_config_exclude_modules(self): + """ + Skip quant for modules in QuantConfig.exclude_modules. + kv_cache_quant_algo takes precedence over exclude_modules. + kv_cache_quant_algo, if not None, is set for non-Attention + modules too, which is the same practice as when there's no + exclude_modules. + """ quant_config = self.model_config.quant_config kv_cache_quant_algo = None if quant_config: @@ -486,6 +488,10 @@ class DecoderModelForCausalLM(nn.Module, None) is not None: module.quant_config = new_config + def __post_init__(self): + self.apply_layerwise_quant_config() + self.apply_quant_config_exclude_modules() + for _, module in self.named_modules(): if callable(getattr(module, "create_weights", None)): module.create_weights() diff --git a/tensorrt_llm/_torch/models/modeling_vila.py b/tensorrt_llm/_torch/models/modeling_vila.py index 99820c1954..e69851cc2f 100644 --- a/tensorrt_llm/_torch/models/modeling_vila.py +++ b/tensorrt_llm/_torch/models/modeling_vila.py @@ -35,7 +35,9 @@ from transformers import (AutoConfig, AutoImageProcessor, AutoModel, PreTrainedModel) from ..._utils import nvtx_range -from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, +from ...inputs import (ExtraProcessedInputs, InputProcessor, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...logger import logger from ...sampling_params import SamplingParams @@ -1118,7 +1120,16 @@ class VilaInputProcessor(InputProcessor): @register_auto_model(VilaConfig.model_architecture) -@register_input_processor(VilaInputProcessor, model_type="llava_llama") +@register_input_processor( + VilaInputProcessor, + model_type="llava_llama", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={ + "image": "<image>", + "video": "<vila/video>" + }, + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + )) class VilaModel(PreTrainedModel): config_class = VilaConfig diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 4cc1e5712c..f9f1eaa1df 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -78,6 +78,7 @@ def attn_custom_op_inplace( mrope_position_deltas: Optional[torch.Tensor], attention_window_size: Optional[int], attention_mask_data: Optional[torch.Tensor], + attention_sinks: Optional[torch.Tensor], layer_idx: str, output: torch.Tensor, ) -> None: @@ -92,8 +93,9 @@ def attn_custom_op_inplace( mrope_position_deltas, attention_window_size, attention_mask_data, - False, - output=output) + enable_attn_nvfp4_output=False, + output=output, + attention_sinks=attention_sinks) class Attention(nn.Module): @@ -192,6 +194,8 @@ class Attention(nn.Module): gpus_per_node=config.mapping.gpus_per_node, enable_attention_dp=config.mapping.enable_attention_dp, ) + self.tp_size = tp_size + self.tp_rank = mapping.tp_rank assert self.num_heads % tp_size == 0 self.num_heads = self.num_heads // tp_size self.num_key_value_heads = (self.num_key_value_heads + tp_size - @@ -331,6 +335,7 @@ class Attention(nn.Module): enable_attn_nvfp4_output: bool = True, output: Optional[torch.Tensor] = None, output_sf: Optional[torch.Tensor] = None, + attention_sinks: Optional[torch.Tensor] = None, ): out_scale = None @@ -364,7 +369,8 @@ class Attention(nn.Module): attention_mask_data=attention_mask_data, enable_attn_nvfp4_output=enable_attn_nvfp4_output, output=output, - output_sf=output_sf) + output_sf=output_sf, + attention_sinks=attention_sinks) if isinstance(attn_output, tuple): assert len( attn_output @@ -372,6 +378,64 @@ class Attention(nn.Module): return attn_output[0], attn_output[1] return attn_output, None + def forward_impl( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + attention_mask: AttentionMask, + attention_window_size: Optional[int], + attention_mask_data: Optional[torch.Tensor], + mrope_config: Optional[dict], + attention_sinks: Optional[torch.Tensor] = None, + ): + mrope_rotary_cos_sin = None + mrope_position_deltas = None + if mrope_config is not None: + if "mrope_rotary_cos_sin" in mrope_config: + mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"] + if "mrope_position_deltas" in mrope_config: + mrope_position_deltas = mrope_config["mrope_position_deltas"] + + # Currently only TRTLLM and FLASHINFER are torch compile compatible backends. + # Only enable custom inplace op when torch compiling. + use_custom_inplace_op = (self.register_to_config + and (self.attn_backend == "TRTLLM" + or self.attn_backend == "FLASHINFER") + and is_torch_compiling()) + + if use_custom_inplace_op: + output = self.create_output(q) + attn_custom_op_inplace( + q, + k, + v, + attention_mask, + mrope_rotary_cos_sin, + mrope_position_deltas, + attention_window_size, + attention_mask_data, + attention_sinks, + self.layer_idx_str, + output, + ) + else: + output, output_sf = self._attn_impl(q, + k, + v, + attn_metadata, + attention_mask, + mrope_rotary_cos_sin, + mrope_position_deltas, + attention_window_size, + attention_mask_data, + attention_sinks=attention_sinks) + if output_sf is not None: + output = Fp4QuantizedTensor(output, output_sf) + + return output + def forward( self, position_ids: Optional[torch.IntTensor], @@ -383,6 +447,7 @@ class Attention(nn.Module): lora_params: Optional[dict] = None, attention_window_size: Optional[int] = None, attention_mask_data: Optional[torch.Tensor] = None, + attention_sinks: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -414,54 +479,22 @@ class Attention(nn.Module): if qkv_lora is not None: qkv = qkv + qkv_lora - mrope_rotary_cos_sin = None - mrope_position_deltas = None - if mrope_config is not None: - if "mrope_rotary_cos_sin" in mrope_config: - mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"] - if "mrope_position_deltas" in mrope_config: - mrope_position_deltas = mrope_config["mrope_position_deltas"] - - output = None - q, k, v = qkv, None, None q, k, v = self.apply_rope(q, k, v, position_ids) q, k, v = self.convert_qkv(q, k, v) - # Currently only TRTLLM and FLASHINFER are torch compile compatible backends. - # Only enable custom inplace op when torch compiling. - use_custom_inplace_op = (self.register_to_config - and (self.attn_backend == "TRTLLM" - or self.attn_backend == "FLASHINFER") - and is_torch_compiling()) - if use_custom_inplace_op: - output = self.create_output(q) - attn_custom_op_inplace( - q, - k, - v, - attention_mask, - mrope_rotary_cos_sin, - mrope_position_deltas, - attention_window_size, - attention_mask_data, - self.layer_idx_str, - output=output, - ) - else: - output, output_sf = self._attn_impl( - q, - k, - v, - attn_metadata, - attention_mask, - mrope_rotary_cos_sin, - mrope_position_deltas, - attention_window_size, - attention_mask_data, - ) - if output_sf is not None: - output = Fp4QuantizedTensor(output, output_sf) + if attention_sinks is not None: + assert self.attn_backend == "TRTLLM", "Attention sinks are only supported for TRTLLM backend." + + output = self.forward_impl(q, + k, + v, + attn_metadata, + attention_mask, + attention_window_size, + attention_mask_data, + mrope_config=mrope_config, + attention_sinks=attention_sinks) attn_output = self.o_proj(output, all_reduce_params=all_reduce_params, @@ -484,12 +517,17 @@ class Attention(nn.Module): Returns: tuple: A tuple of (q, k, v). """ - q, k, v = self.split_qkv(q, k, v) # If RoPE is fused into the attention OP, do not apply RoPE here. if not self.rope_fusion and position_ids is not None: + q, k, v = self.split_qkv(q, k, v) q, k = self.rotary_emb(position_ids, [q, k]) return q, k, v + def apply_qk_norm(self, q, k): + raise NotImplementedError( + f"QK norm is not implemented for {self.__class__.__name__}." + "Please override the `apply_qk_norm` method in the subclass.") + @torch.library.custom_op("trtllm::mla_custom_op_inplace", mutates_args=("output", )) diff --git a/tensorrt_llm/_torch/modules/embedding.py b/tensorrt_llm/_torch/modules/embedding.py index 7f92cdfd6a..7fee9b515d 100644 --- a/tensorrt_llm/_torch/modules/embedding.py +++ b/tensorrt_llm/_torch/modules/embedding.py @@ -126,8 +126,6 @@ def get_masked_input_and_mask( return input_, ~vocab_mask.unsqueeze(-1) -# We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module. -@torch.compile(options={"max-autotune": True}) def pre_comm_embedding_ops( input_: torch.Tensor, weight: torch.Tensor, @@ -184,6 +182,7 @@ class Embedding(LMHead): mapping: Optional[Mapping] = None, tensor_parallel_mode: Optional[TensorParallelMode] = None, gather_output: bool = False, + enable_torch_compile_for_embedding: Optional[bool] = False, ): super().__init__( embedding_dim=embedding_dim, @@ -193,6 +192,9 @@ class Embedding(LMHead): tensor_parallel_mode=tensor_parallel_mode, gather_output=gather_output, ) + + self.enable_torch_compile_for_embedding = enable_torch_compile_for_embedding + if self.tp_size > 1: slice_width = math.ceil(num_embeddings / self.tp_size) self.vocab_start_index = self.tp_rank * slice_width @@ -204,11 +206,16 @@ class Embedding(LMHead): def forward(self, input): # Run the ops before all_reduce/all_gather. - output = pre_comm_embedding_ops(input, self.weight, self.tp_size, - self.tp_rank, self.tp_mode, - self.vocab_start_index, - self.vocab_end_index, - self.gather_output, self.padding_size) + # We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module. + embedding_ops_func = torch.compile( + pre_comm_embedding_ops, + options={"max-autotune": True}, + disable=not self.enable_torch_compile_for_embedding) + output = embedding_ops_func(input, self.weight, self.tp_size, + self.tp_rank, self.tp_mode, + self.vocab_start_index, + self.vocab_end_index, self.gather_output, + self.padding_size) # Run the all_reduce/all_gather. if self.tp_size > 1: diff --git a/tensorrt_llm/_torch/modules/fused_moe/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/__init__.py index a385d24cdd..053ecaa25f 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/__init__.py +++ b/tensorrt_llm/_torch/modules/fused_moe/__init__.py @@ -1,6 +1,7 @@ from .create_moe import create_moe, get_moe_cls from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE +from .fused_moe_triton import TritonFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE from .fused_moe_wide_ep import WideEPMoE @@ -13,10 +14,12 @@ from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod, Llama4RenormalizeMoeRoutingMethod, LoadBalancedMoeRoutingMethod, RenormalizeMoeRoutingMethod, RenormalizeNaiveMoeRoutingMethod, RoutingMethodType, - SparseMixerMoeRoutingMethod, StaticMoeRoutingMethod) + SparseMixerMoeRoutingMethod, StaticMoeRoutingMethod, + create_renormalize_expert_load_balanced_logits) __all__ = [ "BaseMoeRoutingMethod", + "create_renormalize_expert_load_balanced_logits", "create_moe", "CuteDslFusedMoE", "CutlassFusedMoE", @@ -35,6 +38,7 @@ __all__ = [ "RoutingMethodType", "SparseMixerMoeRoutingMethod", "StaticMoeRoutingMethod", + "TritonFusedMoE", "TRTLLMGenFusedMoE", "VanillaMoE", "WideEPMoE", diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 0b47e18f60..74f56ee5d6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Dict, Optional, Type import torch @@ -6,9 +6,11 @@ from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig from ...model_config import ModelConfig +from ...utils import AuxStreamType from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE from .fused_moe_deepgemm import DeepGemmFusedMoE +from .fused_moe_triton import TritonFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE from .fused_moe_wide_ep import WideEPMoE @@ -37,16 +39,21 @@ def get_moe_cls( elif moe_backend.upper() == "TRTLLM": if quant_config is not None and ( quant_config.quant_mode.has_fp8_block_scales() - or quant_config.quant_mode.has_nvfp4()): + or quant_config.quant_mode.has_nvfp4() + or quant_config.quant_mode.has_w4a16_mxfp4() + or quant_config.quant_mode.has_w4a8_mxfp4_fp8() + or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()): return TRTLLMGenFusedMoE else: logger.warning( - "TRTLLMGenFusedMoE only supports fp8_block_scales or nvfp4. " + "TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8. " f"Check out details in quant_config: {quant_config}" "Using CutlassFusedMoE instead.") return CutlassFusedMoE elif moe_backend.upper() == "WIDEEP": return WideEPMoE + elif moe_backend.upper() == "TRITON": + return TritonFusedMoE else: raise ValueError(f"Unsupported moe backend: {moe_backend}") @@ -60,10 +67,14 @@ def create_moe( reduce_results: bool = False, model_config: ModelConfig = ModelConfig(), override_quant_config: Optional[QuantConfig] = None, - aux_stream: Optional[torch.cuda.Stream] = None, + aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, + bias: bool = False, apply_router_weight_on_input: bool = False, layer_idx: Optional[int] = None, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, ) -> MoE: moe_cls = get_moe_cls(model_config, routing_method, dtype, override_quant_config) @@ -72,6 +83,20 @@ def create_moe( if moe_load_balancer is not None: assert moe_cls == WideEPMoE, "MoE Load Balance is only supported in WideEPMoE now." + if bias: + assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE + ], f"bias not supported in {moe_cls.__name__}." + + if swiglu_alpha is not None or swiglu_beta is not None: + assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE], \ + f"swiglu_alpha and swiglu_beta are only supported in CutlassFusedMoE, TritonFusedMoE and TRTLLMGenFusedMoE, not in {moe_cls.__name__}." + assert swiglu_alpha is not None and swiglu_beta is not None, \ + "Both swiglu_alpha and swiglu_beta must be provided." + + if swiglu_limit is not None: + assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE], \ + f"swiglu_limit is only supported in CutlassFusedMoE, TritonFusedMoE and TRTLLMGenFusedMoE, not in {moe_cls.__name__}." + if moe_cls == TRTLLMGenFusedMoE: assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE." @@ -84,7 +109,11 @@ def create_moe( reduce_results=reduce_results, model_config=model_config, weight_loading_mode=weight_loading_mode, + bias=bias, layer_idx=layer_idx, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit, ) elif moe_cls == CutlassFusedMoE: return moe_cls( @@ -95,10 +124,14 @@ def create_moe( dtype=dtype, reduce_results=reduce_results, model_config=model_config, - aux_stream=aux_stream, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, + bias=bias, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit, ) elif moe_cls == WideEPMoE: return moe_cls( @@ -109,7 +142,7 @@ def create_moe( dtype=dtype, reduce_results=reduce_results, model_config=model_config, - aux_stream=aux_stream, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, @@ -137,7 +170,7 @@ def create_moe( dtype=dtype, reduce_results=reduce_results, model_config=model_config, - aux_stream=aux_stream, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, @@ -151,10 +184,28 @@ def create_moe( dtype=dtype, reduce_results=reduce_results, model_config=model_config, - aux_stream=aux_stream, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, ) + elif moe_cls == TritonFusedMoE: + assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TritonFusedMoE." + + return moe_cls( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + weight_loading_mode=weight_loading_mode, + bias=bias, + layer_idx=layer_idx, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit, + ) else: raise ValueError(f"Unsupported moe backend: {moe_cls}") diff --git a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py index 5ad3702481..c852bdb929 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py +++ b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py @@ -154,24 +154,6 @@ class VariableLengthLowLatencyBuffer: # Later, you can use our GEMM library to do the computation with this specific format return recv_hidden_states, recv_expert_count, handle - def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor, - scales: torch.Tensor, topk_idx: torch.Tensor, - num_max_dispatch_tokens_per_rank: int, - num_experts: int): - assert num_experts == self.num_experts - - # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay) - recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \ - self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts) - assert event.event is None - assert hook is None - - # NOTES: the actual tensor will not be received only if you call `hook()`, - # it is useful for double-batch overlapping, but **without any SM occupation** - # If you don't want to overlap, please set `return_recv_hook=False` - # Later, you can use our GEMM library to do the computation with this specific format - return recv_hidden_states, recv_scales, recv_expert_count, handle - def low_latency_combine(self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple): @@ -184,6 +166,30 @@ class VariableLengthLowLatencyBuffer: # NOTES: the same behavior as described in the dispatch kernel return combined_hidden_states + def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor, + scales: torch.Tensor, topk_idx: torch.Tensor, + num_max_dispatch_tokens_per_rank: int, + num_experts: int): + assert num_experts == self.num_experts + + recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \ + self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts) + assert event.event is None + assert hook is None + + return recv_hidden_states, recv_scales, recv_expert_count, handle + + def low_latency_combine_fp4(self, hidden_states: torch.Tensor, + global_scales: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, handle: Tuple): + combined_hidden_states, event, hook = \ + self.buffer.low_latency_combine_fp4(hidden_states, global_scales, topk_idx, topk_weights, handle) + assert event.event is None + assert hook is None + + return combined_hidden_states + def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None: self.buffer.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 815dae6476..572ed0e061 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import torch import torch.nn.functional as F @@ -7,7 +7,7 @@ import torch.nn.functional as F from tensorrt_llm._utils import get_sm_version from ...model_config import ModelConfig -from ...utils import Fp4QuantizedTensor +from ...utils import AuxStreamType, Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE from .quantization import MoEWeightLoadingMode from .routing import BaseMoeRoutingMethod @@ -97,7 +97,7 @@ class CuteDslFusedMoE(CutlassFusedMoE): top_k (int): Number of top experts to select for each input token. hidden_size (int): Size of the hidden state. intermediate_size (int): Size of the intermediate state. - aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. @@ -118,7 +118,8 @@ class CuteDslFusedMoE(CutlassFusedMoE): dtype: Optional[torch.dtype] = None, reduce_results: bool = False, model_config: ModelConfig = ModelConfig(), - aux_stream: Optional[torch.cuda.Stream] = None, + aux_stream_dict: Optional[Dict[AuxStreamType, + torch.cuda.Stream]] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. VANILLA, apply_router_weight_on_input: bool = False, @@ -133,7 +134,7 @@ class CuteDslFusedMoE(CutlassFusedMoE): dtype=dtype, reduce_results=reduce_results, model_config=model_config, - aux_stream=aux_stream, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 025b112034..34bb61a7ab 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -1,15 +1,26 @@ +import os +from functools import cached_property from typing import Dict, List, Optional, Union import torch -from ...distributed import allgather, reducescatter +from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe +from tensorrt_llm.math_utils import pad_up + +from ...distributed import allgather from ...model_config import ModelConfig -from ...utils import EventType, Fp4QuantizedTensor, ceil_div, swizzle_sf +from ...utils import (AuxStreamType, EventType, Fp4QuantizedTensor, ceil_div, + swizzle_sf) from .interface import MoE -from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, - FP8QDQFusedMoEMethod, MoEWeightLoadingMode, - NVFP4CutlassFusedMoEMethod, - UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod) + +# isort: off +from .quantization import ( + DeepSeekFP8BlockScalesFusedMoEMethod, FP8QDQFusedMoEMethod, + MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod, UnquantizedFusedMoEMethod, + INT8WoqPerChannelFusedMoEMethod, W4A8MXFP4FP8CutlassFusedMoEMethod, + W4A8MXFP4MXFP8CutlassFusedMoEMethod, WFP4A16FusedMoEMethod, + WInt4AFP8FusedMoEMethod) +# isort: on from .routing import BaseMoeRoutingMethod @@ -22,7 +33,7 @@ class CutlassFusedMoE(MoE): top_k (int): Number of top experts to select for each input token. hidden_size (int): Size of the hidden state. intermediate_size (int): Size of the intermediate state. - aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. @@ -51,11 +62,16 @@ class CutlassFusedMoE(MoE): dtype: Optional[torch.dtype] = None, reduce_results: bool = False, model_config: ModelConfig = ModelConfig(), - aux_stream: Optional[torch.cuda.Stream] = None, + aux_stream_dict: Optional[Dict[AuxStreamType, + torch.cuda.Stream]] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. VANILLA, + bias: bool = False, apply_router_weight_on_input: bool = False, layer_idx: Optional[int] = None, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, ): super().__init__( @@ -67,8 +83,18 @@ class CutlassFusedMoE(MoE): reduce_results=reduce_results, model_config=model_config, weight_loading_mode=weight_loading_mode, + bias=bias, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit, ) + if model_config.quant_config and model_config.quant_config.layer_quant_mode.has_w4a16_mxfp4( + ): + self.hidden_size = ((self.hidden_size + 127) // 128) * 128 + self.intermediate_size_per_partition = ( + (self.intermediate_size_per_partition + 127) // 128) * 128 + self.layer_idx = layer_idx self.num_slots = self.num_experts @@ -85,15 +111,15 @@ class CutlassFusedMoE(MoE): assert len( self.initial_local_expert_ids) == self.expert_size_per_partition - max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - if self.use_dp: - max_num_tokens *= model_config.mapping.world_size - self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens + moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied - if self.moe_max_num_tokens < max_num_tokens: - self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream( - ) + if self.moe_max_num_tokens < moe_max_num_tokens: + self.aux_stream = aux_stream_dict[ + AuxStreamType. + MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream( + ) self.event_dict = { key: torch.cuda.Event() for key in [EventType.Main, EventType.MoeChunkingOverlap] @@ -112,9 +138,21 @@ class CutlassFusedMoE(MoE): self.has_been_profiled = False self.has_been_profiled_min_latency = False + # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future. + self.alltoall_workspace = None + self.alltoall_prepare_workspace = None + if self.enable_alltoall: + MnnvlMemory.initialize() + self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( + model_config.mapping) + self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace( + model_config.mapping) + # If True, the router weight will be multiplied on the input rather than at the end of FC2 self.apply_router_weight_on_input = apply_router_weight_on_input + self.use_fused_finalize = not model_config.moe_disable_finalize_fusion + self._weights_created = False if not model_config.skip_create_weights_in_init: self.create_weights() @@ -130,8 +168,10 @@ class CutlassFusedMoE(MoE): if not (self.quant_config.quant_mode.has_nvfp4() | self.quant_config.quant_mode.has_fp8_block_scales() | self.quant_config.quant_mode.has_fp8_qdq() - | self.quant_config.quant_mode. - is_int4_weight_only_per_group()): + | self.quant_config.quant_mode.is_weight_only() + | self.quant_config.quant_mode.has_w4a8_mxfp4_fp8() + | self.quant_config.quant_mode.has_w4a16_mxfp4() + | self.quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()): raise ValueError( f"unsupported quantization mode: {self.quant_config.quant_mode}" ) @@ -142,6 +182,21 @@ class CutlassFusedMoE(MoE): return self.quant_config and self.quant_config.quant_mode.is_int4_weight_only_per_group( ) + @property + def has_int8_woq_per_channel(self): + return self.quant_config.layer_quant_mode.is_int8_weight_only( + ) and not self.quant_config.layer_quant_mode.has_per_group_scaling() + + @cached_property + def enable_alltoall(self): + return (self.mapping.moe_ep_size > self.routing_method.experts_per_token + and self.routing_method.experts_per_token % 4 == + 0 # alltoall without allgather only supports top_k % 4 == 0 + and self.mapping.enable_attention_dp + and self.mapping.tp_size > 1 + and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1" + and MnnvlMemory.supports_mnnvl()) + def _get_quant_method(self): if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( exclude_kv_cache=True): @@ -154,6 +209,14 @@ class CutlassFusedMoE(MoE): elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( ): return WInt4AFP8FusedMoEMethod() + elif self.has_int8_woq_per_channel: + return INT8WoqPerChannelFusedMoEMethod() + elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8(): + return W4A8MXFP4FP8CutlassFusedMoEMethod() + elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4(): + return WFP4A16FusedMoEMethod() + elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8(): + return W4A8MXFP4MXFP8CutlassFusedMoEMethod() else: raise ValueError( f"Unsupported quantization mode: {self.quant_config.quant_mode}" @@ -171,24 +234,6 @@ class CutlassFusedMoE(MoE): self._weights_created = True self._check_configs() - def reducescatter_or_allreduce( - self, - inputs, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, - ): - outputs = inputs - if self.parallel_size > 1: - if self.use_dp: - outputs = reducescatter( - inputs, - self.mapping, - dim=0, - sizes=None if use_dp_padding else all_rank_num_tokens) - elif self.reduce_results: - outputs = self.all_reduce(inputs) - return outputs - def forward_chunk( self, x: Union[torch.Tensor, Fp4QuantizedTensor], @@ -222,20 +267,32 @@ class CutlassFusedMoE(MoE): run_post_quant_allgather = self.use_dp and self.parallel_size > 1 # quantize inputs use_deepseek_fp8_block_scale = False - use_w4a8_group_scaling = False + use_w4_group_scaling = False + use_int8_woq_per_channel = False + use_mxfp8_act_scaling = False weight_dtype = self.w3_w1_weight.dtype x_sf = None + x_row = x.shape[0] + x_col = x.shape[1] if self.has_any_quant: - if self.has_fp8_qdq: + if self.has_fp8_qdq or self.has_w4a8_mxfp4_fp8: x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( x, self.fc31_input_dequant) elif self.has_deepseek_fp8_block_scales: use_deepseek_fp8_block_scale = True elif self.has_w4afp8: - use_w4a8_group_scaling = True + use_w4_group_scaling = True weight_dtype = torch.quint4x2 + elif self.has_w4a16_mxfp4: + pad_size = self.hidden_size - x.shape[1] + original_hidden_size = x.shape[1] + x = torch.nn.functional.pad(x, (0, pad_size)) + use_w4_group_scaling = True + weight_dtype = torch.uint8 + elif self.has_int8_woq_per_channel: + use_int8_woq_per_channel = True elif self.has_nvfp4: - if run_post_quant_allgather: + if run_post_quant_allgather or self.enable_alltoall: if isinstance(x, Fp4QuantizedTensor): assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication" x_row = x.shape[0] @@ -253,13 +310,86 @@ class CutlassFusedMoE(MoE): x, x_sf = torch.ops.trtllm.fp4_quantize( x, self.fc31_input_scale, self.scaling_vector_size, False, True) + elif self.has_w4a8_mxfp4_mxfp8: + use_mxfp8_act_scaling = True + if run_post_quant_allgather or self.enable_alltoall: + x, x_sf = torch.ops.trtllm.mxfp8_quantize( + x, False, alignment=self.quant_method.weight_alignment) + else: + x, x_sf = torch.ops.trtllm.mxfp8_quantize( + x, True, alignment=self.quant_method.weight_alignment) + # Update x_row and x_col to the padded shape + x_row, x_col = x.shape[0], x.shape[1] else: raise ValueError( f"unsupported quantization mode: {self.quant_config.quant_mode}" ) - # gather inputs for attention dp - if run_post_quant_allgather: + # Prepare additional information for profiling in case padding is applied when using alltoall. + # Only the non-alltoall case is considered for profiling in the warmup phase. + # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall. + if self.enable_alltoall: + if all_rank_num_tokens is not None: + tuner_num_tokens = sum(all_rank_num_tokens) + else: + tuner_num_tokens = x.shape[0] * self.mapping.tp_size + tuner_top_k = token_selected_experts.shape[1] + else: + tuner_num_tokens = None + tuner_top_k = None + + # Alltoall or allgather for attention DP + token_count = x.shape[0] + alltoall_info = None # Store for later combine + if self.enable_alltoall: + assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall" + # Prepare alltoall indices + top_k = self.routing_method.experts_per_token + max_num_token = max( + all_rank_num_tokens) if all_rank_num_tokens else token_count + + # Handle case where token_final_scales might be None (when apply_router_weight_on_input=True) + if token_final_scales is None: + token_final_scales = torch.ones_like(token_selected_experts, + dtype=torch.float32) + + # TODO: support alltoall without allgather for top_k % 4 != 0 + assert top_k % 4 == 0, "alltoall without allgather only supports top_k % 4 == 0" + assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" + alltoall_info, token_selected_experts, token_final_scales, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + token_selected_experts, token_final_scales, None, + self.alltoall_prepare_workspace, max_num_token, self.ep_rank, + self.ep_size, self.num_experts, self.num_experts, top_k) + + # Dispatch alltoall (common for both paths) + x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info, + self.alltoall_workspace, + self.ep_rank, self.ep_size) + if x_sf is not None: + x_sf = x_sf.view(x_row, ceil_div(x_col, + self.scaling_vector_size)) + + # Pad dim[1] to 16 bytes alignment for alltoall + # TODO: Remove this padding if possible + sf_per_16bytes = 16 // x_sf.element_size() + x_sf_col_orig = x_sf.shape[1] + x_sf_col = pad_up(x_sf_col_orig, sf_per_16bytes) + if x_sf_col > x_sf_col_orig: + x_sf = torch.nn.functional.pad( + x_sf, (0, x_sf_col - x_sf_col_orig)) + + x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info, + self.alltoall_workspace, + self.ep_rank, self.ep_size) + x_row = x_sf.shape[0] + + # TODO: Remove this slicing required by padding if possible + x_sf = x_sf[:, :x_sf_col_orig].contiguous() + + x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) + + elif run_post_quant_allgather: + # Original allgather logic if x_sf is not None: x_sf = x_sf.view(x_row, ceil_div(x_col, self.scaling_vector_size)) @@ -281,12 +411,16 @@ class CutlassFusedMoE(MoE): token_selected_experts, token_final_scales, self.w3_w1_weight.view(weight_dtype), - None, # fc1_expert_biases + self.w3_w1_bias, self.w2_weight.view(weight_dtype), - None, # fc2_expert_biases + self.w2_bias, output_dtype, quant_scales=self.quant_scales, input_sf=x_sf, + swizzled_input_sf=True, + swiglu_alpha=self.swiglu_alpha, + swiglu_beta=self.swiglu_beta, + swiglu_limit=self.swiglu_limit, tp_size=self.tp_size, tp_rank=self.tp_rank, ep_size=self.ep_size, @@ -295,15 +429,39 @@ class CutlassFusedMoE(MoE): cluster_rank=self.cluster_rank, enable_alltoall=self.enable_alltoall, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - use_w4a8_group_scaling=use_w4a8_group_scaling, + use_w4_group_scaling=use_w4_group_scaling, + use_int8_woq_per_channel=use_int8_woq_per_channel, + use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=False, + use_fused_finalize=self.use_fused_finalize, tune_max_num_tokens=self.tune_max_num_tokens, + tuner_num_tokens=tuner_num_tokens, + tuner_top_k=tuner_top_k, ) - # Custom op requires all inputs are in the same type. # Only in cutlass_min_latency_mode, the output is a list of tensors. # Otherwise, the output should be unpacked as a single tensor. final_hidden_states = final_hidden_states[0] + # TODO: Fuse this for padded MXFP4. + final_hidden_states = final_hidden_states[:, :self. + hidden_size].contiguous() + + if self.has_w4a16_mxfp4: + final_hidden_states = final_hidden_states[:, : + original_hidden_size].contiguous( + ) + + # Combine results if using alltoall + if self.enable_alltoall and alltoall_info is not None: + top_k = self.routing_method.experts_per_token + final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine( + final_hidden_states, + alltoall_info, + self.alltoall_workspace, + ep_rank=self.ep_rank, + ep_size=self.ep_size, + top_k=top_k, + token_count=token_count) return final_hidden_states @@ -325,7 +483,7 @@ class CutlassFusedMoE(MoE): use_dp_padding: Optional[bool] = None, ) -> torch.Tensor: assert do_finalize, "CutlassFusedMoE does not support do_finalize=False" - if self.use_dp: + if self.use_dp and self.parallel_size > 1: assert all_rank_num_tokens is not None assert use_dp_padding is not None num_rows = sum(all_rank_num_tokens) @@ -420,7 +578,7 @@ class CutlassFusedMoE(MoE): outputs = torch.cat(outputs_list) - if self.use_dp: + if self.use_dp and self.parallel_size > 1: rank = self.mapping.tp_rank outputs = outputs[:all_rank_num_tokens[rank]] return outputs diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index bcdf8d4415..a5ca05694b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import torch import torch.nn.functional as F @@ -11,9 +11,10 @@ from tensorrt_llm._utils import nvtx_range from ...distributed import allgather from ...model_config import ModelConfig -from ...utils import Fp4QuantizedTensor +from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE -from .quantization import MoEWeightLoadingMode +from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, + MoEWeightLoadingMode, UnquantizedFusedMoEMethod) from .routing import BaseMoeRoutingMethod @@ -87,6 +88,7 @@ def _masked_index_copy_group_quant_fp8( def masked_index_copy_group_quant_fp8( output: torch.Tensor, + output_s: torch.Tensor, input: torch.Tensor, start_offsets: torch.Tensor, row_indices: torch.Tensor, @@ -107,14 +109,10 @@ def masked_index_copy_group_quant_fp8( col_size = output.shape[1] dim_size = output.shape[2] - # create padded output_s alignment = 4 scale_dim = (dim_size + group_size - 1) // group_size padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment padded_col_size = (col_size + alignment - 1) // alignment * alignment - output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size), - dtype=torch.int32, - device='cuda') # get block/grid/stage/warp num_groups = (dim_size + group_size - 1) // group_size @@ -246,6 +244,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor, @nvtx_range("[DG]") def deepgemm_fp8_group_blockwise_gemm( + d: torch.Tensor, a: torch.Tensor, b: torch.Tensor, sfa: torch.Tensor, @@ -253,10 +252,6 @@ def deepgemm_fp8_group_blockwise_gemm( masked_m: torch.Tensor, expected_m: int, ) -> torch.Tensor: - d = torch.empty((a.shape[0], a.shape[1], b.shape[1]), - device=b.device, - dtype=torch.bfloat16) - # NOTES: shape must be `[G, M, K] @ [G, N, K].mT` assert a.stride(-1) == 1 assert b.stride(-1) == 1 @@ -286,7 +281,16 @@ def deepgemm_fp8_group_blockwise_gemm( masked_m, expected_m, disable_ue8m0_cast=True) - return d + return + + +def set_strides(workspace: torch.Tensor, g: int, m: int, k: int): + workspace = workspace[0:g * m * k] + workspace = workspace.as_strided( + size=(g, m, k), + stride=(m * k, k, 1), + ) + return workspace class DeepGemmFusedMoE(CutlassFusedMoE): @@ -298,7 +302,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE): top_k (int): Number of top experts to select for each input token. hidden_size (int): Size of the hidden state. intermediate_size (int): Size of the intermediate state. - aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. @@ -319,12 +323,25 @@ class DeepGemmFusedMoE(CutlassFusedMoE): dtype: Optional[torch.dtype] = None, reduce_results: bool = False, model_config: ModelConfig = ModelConfig(), - aux_stream: Optional[torch.cuda.Stream] = None, + aux_stream_dict: Optional[Dict[AuxStreamType, + torch.cuda.Stream]] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. VANILLA, apply_router_weight_on_input: bool = False, layer_idx: Optional[int] = None, ): + if model_config.moe_max_num_tokens is None: + moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + # The default moe_max_num_tokens is calculated from the following formula: + # max_isl = 8196, max_batch_size = 1024, mtp = 0 + # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344 + # moe_max_num_tokens = max_num_tokens * 2 = 18688 + # It can avoid OOM for 8k/1k cases. + default_moe_max_num_tokens = 18688 + if moe_max_num_tokens > default_moe_max_num_tokens: + model_config._frozen = False + model_config.moe_max_num_tokens = default_moe_max_num_tokens + model_config._frozen = True super().__init__( routing_method=routing_method, @@ -334,12 +351,55 @@ class DeepGemmFusedMoE(CutlassFusedMoE): dtype=dtype, reduce_results=reduce_results, model_config=model_config, - aux_stream=aux_stream, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, ) + def get_workspace(self, m_max: int, group_size: int): + hidden_size = self.hidden_size + intermediate_size = self.intermediate_size + num_experts = self.expert_size_per_partition + + # create workspace + fp8_dim = max(hidden_size, intermediate_size) + workspace_0 = torch.empty((num_experts * m_max * fp8_dim), + dtype=torch.float8_e4m3fn, + device='cuda') + workspace_1 = torch.empty( + (num_experts * m_max * max(intermediate_size * 2, hidden_size)), + dtype=torch.bfloat16, + device='cuda') + + # create workspace for scaling factors + m_padded = fp8_utils.align(m_max, 4) + scale_k = fp8_utils.ceil_div(fp8_dim, group_size) + scale_k_padded = fp8_utils.align(scale_k, 4) + workspace_sf = torch.empty( + (num_experts * (scale_k_padded // 4) * m_padded), + dtype=torch.int32, + device='cuda') + + workspace = { + "workspace_0": workspace_0, + "workspace_1": workspace_1, + "workspace_sf": workspace_sf, + } + return workspace + + def _get_quant_method(self): + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True): + if self.quant_config.layer_quant_mode.has_fp8_block_scales(): + return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm() + else: + raise ValueError( + f"Unsupported quantization mode: {self.quant_config.quant_mode}" + ) + else: + return UnquantizedFusedMoEMethod() + @nvtx_range("[DG] forward") def forward_chunk( self, @@ -348,6 +408,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE): output_dtype: Optional[torch.dtype] = None, all_rank_num_tokens: Optional[List[int]] = None, use_dp_padding: Optional[bool] = None, + workspace: Optional[dict] = None, ) -> torch.Tensor: if isinstance(x, Fp4QuantizedTensor): assert output_dtype is not None @@ -423,22 +484,38 @@ class DeepGemmFusedMoE(CutlassFusedMoE): masked_m, token_to_expert_map = preprocess_after_permute( expert_first_token_offset_tensor, permuted_data_tensor) - m_max = (x.shape[0] + 127) // 128 * 128 expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - 1) // self.expert_size_per_partition - act_input_fp8 = torch.empty( - (self.expert_size_per_partition, m_max, self.hidden_size), - dtype=torch.float8_e4m3fn, - device='cuda') + + # padding and quantization + m_max = fp8_utils.align(x.shape[0], 128) + act_input_fp8 = set_strides(workspace["workspace_0"], + self.expert_size_per_partition, m_max, + self.hidden_size) + + m_padded = fp8_utils.align(m_max, 4) + scale_k = fp8_utils.ceil_div(self.hidden_size, 128) + scale_k_padded = fp8_utils.align(scale_k, 4) + act_input_sf = set_strides(workspace["workspace_sf"], + self.expert_size_per_partition, + scale_k_padded // 4, m_padded) + act_input_sf = masked_index_copy_group_quant_fp8( act_input_fp8, + act_input_sf, permuted_data_tensor, expert_first_token_offset_tensor, token_to_expert_map, group_size=128) - h1 = deepgemm_fp8_group_blockwise_gemm( + # grouped gemm 1 + h1 = set_strides(workspace["workspace_1"], + self.expert_size_per_partition, m_max, + self.intermediate_size * 2) + + deepgemm_fp8_group_blockwise_gemm( + d=h1, a=act_input_fp8, b=self.w3_w1_weight, sfa=act_input_sf, @@ -446,9 +523,33 @@ class DeepGemmFusedMoE(CutlassFusedMoE): masked_m=masked_m, expected_m=expected_m, ) - act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( - input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True) - h3 = deepgemm_fp8_group_blockwise_gemm( + + # activation and quantization + act_input_fp8 = set_strides(workspace["workspace_0"], + self.expert_size_per_partition, m_max, + self.intermediate_size) + + scale_k = fp8_utils.ceil_div(self.intermediate_size, 128) + scale_k_padded = fp8_utils.align(scale_k, 4) + act_input_sf = set_strides(workspace["workspace_sf"], + self.expert_size_per_partition, + scale_k_padded // 4, m_padded) + + act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( + output=act_input_fp8, + output_scale=act_input_sf, + input=h1, + quant_group_size=128, + masked_m=masked_m, + scale_ue8m0=True) + + # grouped gemm 2 + h3 = set_strides(workspace["workspace_1"], + self.expert_size_per_partition, m_max, + self.hidden_size) + + deepgemm_fp8_group_blockwise_gemm( + d=h3, a=act_input_fp8, b=self.w2_weight, sfa=act_input_sf, @@ -457,6 +558,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE): expected_m=expected_m, ) + # gather and finalize triton_masked_index_gather(permuted_data_tensor, h3, expert_first_token_offset_tensor, token_to_expert_map) @@ -481,3 +583,137 @@ class DeepGemmFusedMoE(CutlassFusedMoE): ) return final_hidden_states + + def forward( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + do_finalize: bool = True, # used by other MoE backends + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + all_rank_max_num_tokens: Optional[int] = None, + use_dp_padding: Optional[bool] = None, + ) -> torch.Tensor: + assert do_finalize, "CutlassFusedMoE does not support do_finalize=False" + if self.use_dp and self.parallel_size > 1: + assert all_rank_num_tokens is not None + assert use_dp_padding is not None + num_rows = sum(all_rank_num_tokens) + else: + num_rows = x.shape[0] + + # In case of num_rows is larger than max_chunk_size * 2, we need to split the input into multiple chunks. + # Because we will use two streams in chunked moe and preallocate two workspaces. + num_chunks = 1 + if num_rows > self.moe_max_num_tokens * 2: + num_chunks = (num_rows + self.moe_max_num_tokens - + 1) // self.moe_max_num_tokens + + if use_dp_padding: + all_rank_num_tokens_padded = [all_rank_max_num_tokens + ] * len(all_rank_num_tokens) + else: + all_rank_num_tokens_padded = all_rank_num_tokens + + if num_chunks == 1: + # create workspace + num_rows = x.shape[0] + if self.use_dp: + num_rows = sum(all_rank_num_tokens_padded) + m_max = fp8_utils.align(num_rows, 128) + workspace = self.get_workspace(m_max, 128) + outputs = self.forward_chunk( + x, + router_logits, + output_dtype, + all_rank_num_tokens=all_rank_num_tokens_padded, + use_dp_padding=use_dp_padding, + workspace=workspace) + outputs = self.reducescatter_or_allreduce( + outputs, + all_rank_num_tokens=all_rank_num_tokens_padded, + use_dp_padding=use_dp_padding) + else: + if self.use_dp: + all_rank_chunk_size_list = [ + self.split_chunk(val, num_chunks) + for val in all_rank_num_tokens_padded + ] + all_rank_num_tokens_list = [[ + val[idx_chunk] for val in all_rank_chunk_size_list + ] for idx_chunk in range(num_chunks)] + chunk_size_list = all_rank_chunk_size_list[self.rank] + else: + all_rank_num_tokens_list = [None] * num_chunks + chunk_size_list = self.split_chunk(x.shape[0], num_chunks) + + # create workspace + chunk_size_0 = sum(all_rank_num_tokens_list[0] + ) if self.use_dp else chunk_size_list[0] + chunk_size_1 = sum(all_rank_num_tokens_list[1] + ) if self.use_dp else chunk_size_list[1] + workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128), + 128) + workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128), + 128) + + x_list = x.split(chunk_size_list) + router_logits_list = router_logits.split(chunk_size_list) + + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() + + def _forward_chunk(x_, router_logits_, idx, workspace): + return self.forward_chunk( + x_, + router_logits_, + all_rank_num_tokens=all_rank_num_tokens_list[idx] + if self.use_dp else None, + use_dp_padding=use_dp_padding, + workspace=workspace) + + def _reducescatter_or_allreduce(x_, idx): + return self.reducescatter_or_allreduce( + x_, + all_rank_num_tokens=all_rank_num_tokens_list[idx], + use_dp_padding=use_dp_padding) + + outputs_list = [] + # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap + for idx_chunk, (x, router_logits) in enumerate( + zip(x_list, router_logits_list)): + + if idx_chunk % 2 == 0: + with torch.cuda.stream(self.aux_stream): + outputs = _forward_chunk(x, router_logits, idx_chunk, + workspace_0) + if idx_chunk > 0: + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], idx_chunk - 1) + else: + outputs = _forward_chunk(x, router_logits, idx_chunk, + workspace_1) + with torch.cuda.stream(self.aux_stream): + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], idx_chunk - 1) + + outputs_list.append(outputs) + + if num_chunks % 2 == 0: + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], -1) + else: + with torch.cuda.stream(self.aux_stream): + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], -1) + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.MoeChunkingOverlap].record() + self.event_dict[EventType.MoeChunkingOverlap].wait() + + outputs = torch.cat(outputs_list) + + if self.use_dp and self.parallel_size > 1: + rank = self.mapping.tp_rank + outputs = outputs[:all_rank_num_tokens[rank]] + return outputs diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py new file mode 100755 index 0000000000..f2ef121757 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py @@ -0,0 +1,1389 @@ +from __future__ import annotations + +import os +import sys +from typing import Dict, List, NamedTuple, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from tensorrt_llm.math_utils import ceil_div + +IS_TRITON_KERNELS_AVAILABLE = False +# We expect to find triton_kernels under $TRITON_ROOT/python/triton_kernels +# Triton upstream commit f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f has been verified. +triton_root = os.getenv('TRITON_ROOT') +if triton_root: + triton_root = os.path.abspath( + os.path.join(triton_root, 'python', 'triton_kernels')) + if os.path.exists(triton_root) and triton_root not in sys.path: + sys.path.insert(0, triton_root) + assert triton.__version__ >= "3.4.0", "Triton kernels are detected but the Triton wheel is too old" + import triton_kernels.swiglu + from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation, + PrecisionConfig, matmul_ogs) + from triton_kernels.numerics import InFlexData + from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch + from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor + from triton_kernels.tensor_details import layout + IS_TRITON_KERNELS_AVAILABLE = True + +from ...model_config import ModelConfig +from ..linear import TensorParallelMode, load_weight_shard +from .interface import MoE +from .quantization import (FusedMoEMethodBase, MoEWeightLoadingMode, + load_activation_scales_fp8_qdq, + requantize_expert_w3_w1_weight_fp8_qdq) +from .routing import BaseMoeRoutingMethod, RenormalizeMoeRoutingMethod + + +# Triton kernels has hardcoded beta = 1, so we use this implementation when beta is not 1 +def swiglu_torch(a: torch.Tensor, alpha: float, beta: float, + limit: Optional[float]) -> torch.Tensor: + a_glu = a[..., ::2] + if limit is not None: + a_glu = a_glu.clamp(max=limit) + a_linear = a[..., 1::2] + if limit is not None: + a_linear = a_linear.clamp(min=-limit, max=limit) + + out_glu = a_glu * torch.sigmoid(alpha * a_glu) + out = out_glu * (a_linear + beta) + return out + + +def shuffle_weight_for_activation_kernel( + w3_w1_weight: torch.Tensor) -> torch.Tensor: + temp_weight = w3_w1_weight.clone() + last_dim = w3_w1_weight.shape[-1] + assert w3_w1_weight.dim() in [1, 2, 3] + # n_dims = 1: Single expert bias (like the unquantized case) + # n_dims = 2: Single expert weight (like the unquantized case) + # n_dims = 3: Multiple experts weight (re-quantization for fp8 qdq) + w3_w1_weight[..., 0::2] = temp_weight[..., last_dim // 2:] + w3_w1_weight[..., 1::2] = temp_weight[..., 0:last_dim // 2] + return w3_w1_weight + + +# This kernel remaps the global routing information (bitmatrix and indices) +# to a local view for this specific EP worker. +# +# The bitmask is shifted so that the worker's slice of experts starts at bit 0. +# Since the slice may not align with 32-bit word boundaries, this is done by +# loading two consecutive words (v1, v2) and "stitching" the result together. +# The expression `(v1 >> start_bit) | (v2 << (32 - start_bit))` takes the +# upper bits from v1 and combines them with the lower bits from v2 to form the +# new, correctly aligned word. +@triton.jit +def _routing_shift_bitmatrix_range(Bitmatrix, stride_bm, stride_bn, Indices, + stride_im, stride_in, n_words, n_cols, + slice_start, slice_end, + BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + start_word = slice_start // 32 + start_bit = slice_start % 32 + + for col0 in range(0, n_words, BLOCK_N): + w = col0 + tl.arange(0, BLOCK_N) # dst‐word indices + dst_mask = w < n_words + + # corresponding source words (and the next word for carry bits) + src1_w = start_word + w + src2_w = src1_w + 1 + + ptr1 = Bitmatrix + pid_m * stride_bm + src1_w * stride_bn + ptr2 = Bitmatrix + pid_m * stride_bm + src2_w * stride_bn + + v1 = tl.load(ptr1, mask=src1_w < n_words, other=0).to(tl.uint32) + v2 = tl.load(ptr2, mask=src2_w < n_words, other=0).to(tl.uint32) + + # shift the slice down to bit‐0 + shifted = tl.where(start_bit == 0, v1, + (v1 >> start_bit) | (v2 << (32 - start_bit))) + + # write back in place; bits past the region are already zero + tl.store(Bitmatrix + pid_m * stride_bm + w * stride_bn, + shifted.to(tl.int32), + mask=dst_mask) + + # Fix the indices associated with the bitmatrix. + for col0 in range(0, n_cols, BLOCK_N): + offs = col0 + tl.arange(0, BLOCK_N) + mask_i = offs < n_cols + + ptr = Indices + pid_m * stride_im + offs * stride_in + yi = tl.load(ptr, mask=mask_i, other=0).to(tl.int32) + + yi = tl.where(yi < slice_end, yi - slice_start, + yi) # shift inside slice + yi = tl.where(yi < 0, yi + slice_end, yi) # wrap negatives + + tl.store(ptr, yi, mask=mask_i) + + +class TritonEPRouter(): + + def prune_routing_ep(self, expt_scal, expt_indx, bitmatrix, n_expts_tot, + slice_start, slice_end): + from triton_kernels.compaction import compaction + from triton_kernels.routing import _routing_clear_bitmatrix + n_tokens_pad = expt_scal.shape[0] + _routing_shift_bitmatrix_range[(n_tokens_pad, )]( + bitmatrix.storage.data, + bitmatrix.storage.data.stride(0), + bitmatrix.storage.data.stride(1), + expt_indx, + expt_indx.stride(0), + expt_indx.stride(1), + bitmatrix.storage.data.shape[1], + expt_indx.shape[1], + slice_start, + slice_end, + BLOCK_N=512, + ) + _routing_clear_bitmatrix[(n_tokens_pad, )]( + bitmatrix.storage.data, + bitmatrix.storage.data.stride(0), + bitmatrix.storage.data.stride(1), + bitmatrix.storage.data.shape[1], + slice_end - slice_start, + BLOCK_N=512, + ) + # perform compaction to update expt_scal / expt_indx + expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix) + n_expts_tot = slice_end - slice_start + bitmatrix.shape[-1] = n_expts_tot + return expt_scal, expt_indx, bitmatrix + + def __call__(self, + logits, + n_expts_act, + sm_first=False, + expt_indx=None, + ep=1, + node_idx=0, + n_rows=None): + n_expts_tot = logits.shape[-1] + n_expts_local = n_expts_tot // ep + slice_start = node_idx * n_expts_local + slice_end = slice_start + n_expts_local + + from triton_kernels.routing import routing_from_bitmatrix + from triton_kernels.topk import topk + if sm_first: + logits = torch.softmax(logits, dim=-1) + expt_scal, expt_indx, bitmatrix = topk(logits, + n_expts_act, + apply_softmax=not sm_first, + y_indx=expt_indx, + n_rows=n_rows) + # mutate bitmatrix + if ep > 1: + expt_scal, expt_indx, bitmatrix = self.prune_routing_ep( + expt_scal, expt_indx, bitmatrix, n_expts_tot, slice_start, + slice_end) + return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, + n_expts_local, n_expts_act) + + +def maybe_update_stride(weight): + assert weight.dim() == 3 + # For the latest Triton kernels, w.stride(-2)==1 works universally + return weight.transpose(1, 2).contiguous().transpose(1, 2) + + +class TritonUnquantizedFusedMoEMethod(FusedMoEMethodBase): + + def __init__(self, shuffle_weight=True): + super().__init__() + self.shuffle_weight = shuffle_weight + + def create_weights(self, module: torch.nn.Module): + weight_dtype = module.dtype + assert weight_dtype == torch.bfloat16, \ + f"TritonUnquantizedFusedMoEMethod only supports bfloat16 weights, got {weight_dtype}" + + # The Triton kernel accepts the w3_w1_weight in (num_experts, hidden_dim, intermediate_dim * 2) format + w3_w1_weight_shape = (module.expert_size_per_partition, + module.hidden_size, + module.intermediate_size_per_partition * 2) + + # The Triton kernel accepts the w2_weight in (num_experts, intermediate_dim, hidden_dim) format + w2_weight_shape = ( + module.expert_size_per_partition, + module.intermediate_size_per_partition, + module.hidden_size, + ) + super().create_weights(module, + weight_dtype, + w3_w1_weight_shape, + w2_weight_shape, + bias_dtype=torch.float32) + self.setup_quant_scales(module) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = tuple() + + def get_quant_scales(self, module: torch.nn.Module, slot_start, + slot_end) -> tuple[torch.Tensor, ...]: + return tuple() + + def load_expert_w3_w1_weight(self, module: torch.nn.Module, + w1_weight: torch.Tensor, + w3_weight: torch.Tensor, + dst_w3_w1_weight: torch.Tensor): + """ + Load w1 and w3 weights for each expert. + Override this method if you need to preprocess the weights differently. + """ + device = dst_w3_w1_weight.device + assert device.type == "cuda" + + w1_weight_shard = load_weight_shard(w1_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w3_weight_shard = load_weight_shard(w3_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + + w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) + # We use .to here since for Triton the bias is always in float32 and a conversion is needed. + w31_weight_shard = w31_weight_shard.to(dst_w3_w1_weight.dtype) + + # This function is shared by weights and biases, we only do transpose for weights + if w31_weight_shard.dim() == 2: + # Transpose the weights to match the expected format for the Triton gemm kernel + w31_weight_shard = w31_weight_shard.transpose(0, 1).contiguous() + + if self.shuffle_weight: + w31_weight_shard = shuffle_weight_for_activation_kernel( + w31_weight_shard) + + dst_w3_w1_weight.copy_(w31_weight_shard, non_blocking=True) + + def load_expert_w2_weight(self, module: torch.nn.Module, + w2_weight: torch.Tensor, + dst_w2_weight: torch.Tensor): + """ + Load w2 weight for each expert. + Override this method if you need to preprocess the weights differently. + """ + device = dst_w2_weight.device + assert device.type == "cuda" + w2_weight_shard = load_weight_shard(w2_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + # We use .to here since for Triton the bias is always in float32 and a conversion is needed. + w2_weight_shard = w2_weight_shard.to(dst_w2_weight.dtype) + + # This function is shared by weights and biases, we only do transpose for weights + if w2_weight_shard.dim() == 2: + # Transpose the weights to match the expected format for the Triton gemm kernel + w2_weight_shard = w2_weight_shard.transpose(0, 1).contiguous() + else: + assert w2_weight_shard.dim() == 1 + # Handle TP contribution of bias + w2_weight_shard /= module.tp_size + + dst_w2_weight.copy_(w2_weight_shard, non_blocking=True) + + def load_expert_weights_to_dst( + self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode, + load_expert_ids: List[int], dst_w3_w1_weights_tensor: torch.Tensor, + dst_w2_weights_tensor: torch.Tensor, + dst_w3_w1_bias_tensor: Optional[torch.Tensor], + dst_w2_bias_tensor: Optional[torch.Tensor]): + FusedMoEMethodBase.load_expert_weights_to_dst( + self, module, weights, weight_loading_mode, load_expert_ids, + dst_w3_w1_weights_tensor, dst_w2_weights_tensor, + dst_w3_w1_bias_tensor, dst_w2_bias_tensor) + module.w3_w1_weight.data = maybe_update_stride(module.w3_w1_weight.data) + module.w2_weight.data = maybe_update_stride(module.w2_weight.data) + + def apply(self, module: torch.nn.Module, x: torch.Tensor, + router_logits: torch.Tensor) -> torch.Tensor: + # Fetch all the data needed for the Triton kernel + hidden_states = x + expert_logits = router_logits + gemm1_weights = module.w3_w1_weight + gemm2_weights = module.w2_weight + top_k = module.routing_method.experts_per_token + + # hidden_states: (num_tokens, hidden_dim) torch.bfloat16 + # expert_logits: (num_tokens, num_experts) torch.bfloat16 + # gemm1_weights: (num_experts, intermediate_dim * 2, hidden_dim) torch.bfloat16 + # gemm2_weights: (num_experts, hidden_dim, intermediate_dim) torch.bfloat16 + + # Step 1: Routing + num_experts = expert_logits.shape[1] + if num_experts > 1: + rdata, gather_indx, scatter_indx = TritonEPRouter()( + expert_logits, + top_k, + ep=module.ep_size, + node_idx=module.ep_rank) + else: + rdata, gather_indx, scatter_indx = None, None, None + + # Step 2: Gemm1 + # Setup quantization context + pc1 = PrecisionConfig(flex_ctx=FlexCtx(), + allow_tf32=False, + out_dtype=module.dtype) + + # Call the Triton gemm kernel, which also does permutation and activation + alpha = module.swiglu_alpha or 1.0 + beta = module.swiglu_beta or 0.0 + if beta == 1.0: + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, + ("alpha", "limit")), (alpha, module.swiglu_limit), 2) + act_out = matmul_ogs(hidden_states, + gemm1_weights, + module.w3_w1_bias if module.bias else None, + rdata, + gather_indx=gather_indx, + precision_config=pc1, + fused_activation=act) + else: + act_out = matmul_ogs(hidden_states, + gemm1_weights, + module.w3_w1_bias if module.bias else None, + rdata, + gather_indx=gather_indx, + precision_config=pc1) + act_out = swiglu_torch(act_out, alpha, beta, module.swiglu_limit) + + # Step 3: Gemm2 + # Setup quantization context + pc2 = PrecisionConfig(flex_ctx=FlexCtx(), + allow_tf32=False, + out_dtype=module.dtype) + + # Call the Triton kernel, which also does finalization + gemm2_output = matmul_ogs(act_out, + gemm2_weights, + module.w2_bias if module.bias else None, + rdata, + scatter_indx=scatter_indx, + precision_config=pc2, + gammas=rdata.gate_scal if rdata else None) + return gemm2_output + + +class TritonFP8QDQFusedMoEQuantScales(NamedTuple): + fc1_dequant: torch.Tensor + fc2_dequant: torch.Tensor + fc1_input_dequant: torch.Tensor + fc2_input_dequant: torch.Tensor + + +# We inherit from TritonUnquantizedFusedMoEMethod to reuse the weight preprocessing logic +class TritonFP8QDQFusedMoEMethod(TritonUnquantizedFusedMoEMethod): + + def __init__(self): + # Due to the requantization logic in the Triton kernel, we delay the shuffle + super().__init__(shuffle_weight=False) + + def create_weights(self, module: torch.nn.Module): + weight_dtype = torch.float8_e4m3fn + + # The Triton kernel accepts the w3_w1_weight in (num_experts, hidden_dim, intermediate_dim * 2) format + w3_w1_weight_shape = ( + module.expert_size_per_partition, + module.hidden_size, + module.intermediate_size_per_partition * 2, + ) + + # The Triton kernel accepts the w2_weight in (num_experts, intermediate_dim, hidden_dim) format + w2_weight_shape = ( + module.expert_size_per_partition, + module.intermediate_size_per_partition, + module.hidden_size, + ) + FusedMoEMethodBase.create_weights(self, + module, + weight_dtype, + w3_w1_weight_shape, + w2_weight_shape, + bias_dtype=torch.float32) + + fc31_dequant = nn.Parameter(torch.empty( + module.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc31_dequant", fc31_dequant) + + fc2_dequant = nn.Parameter(torch.empty(module.expert_size_per_partition, + dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc2_dequant", fc2_dequant) + + fc31_input_dequant = nn.Parameter(torch.tensor(1., dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc31_input_dequant", fc31_input_dequant) + + fc2_input_dequant = nn.Parameter(torch.tensor(1., dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc2_input_dequant", fc2_input_dequant) + + self.setup_quant_scales(module) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = TritonFP8QDQFusedMoEQuantScales( + fc1_dequant=module.fc31_dequant, + fc2_dequant=module.fc2_dequant, + fc1_input_dequant=module.fc31_input_dequant, + fc2_input_dequant=module.fc2_input_dequant, + ) + + def load_expert_w3_w1_weight_scale_fp8_qdq( + self, w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight_scale: torch.Tensor): + w1_weight_scale = w1_weight_scale[...].reshape([]) + w3_weight_scale = w3_weight_scale[...].reshape([]) + dst_w3_w1_weight_scale.copy_(max(w1_weight_scale, w3_weight_scale), + non_blocking=True) + + def load_expert_w2_weight_scale_fp8(self, w2_weight_scale, + dst_w2_weight_scale: torch.Tensor): + dst_w2_weight_scale.copy_(w2_weight_scale[...].reshape([]), + non_blocking=True) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # Step1: Load input scales. + max_fc31_input_scale, max_fc2_input_scale = load_activation_scales_fp8_qdq( + module, weights) + + # Step2: Load weight scales and requantize w3_w1_weight. + tmp_w3_w1_weight_scale = torch.empty(module.expert_size_per_partition, + dtype=torch.float32) + tmp_w2_weight_scale = torch.empty(module.expert_size_per_partition, + dtype=torch.float32) + + for local_slot_id, expert_id in enumerate( + module.initial_local_expert_ids): + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] + w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] + w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] + elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_weight_scale = weights[f"gate_up_proj_weight_scale"] + w3_weight_scale = weights[f"gate_up_proj_weight_scale"] + w2_weight_scale = weights[f"down_proj_weight_scale"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" + ) + + expert_idx = local_slot_id + + self.load_expert_w3_w1_weight_scale_fp8_qdq( + w1_weight_scale, w3_weight_scale, + tmp_w3_w1_weight_scale[expert_idx]) + + # Shared code from FP8QDQFusedMoEMethod, need to pass a transposed view + requantize_expert_w3_w1_weight_fp8_qdq( + module, w1_weight_scale, w3_weight_scale, + module.w3_w1_weight.data[expert_idx].transpose(0, 1)) + + self.load_expert_w2_weight_scale_fp8( + w2_weight_scale, tmp_w2_weight_scale[expert_idx]) + + # now we can shuffle the weights for the activation kernel + module.w3_w1_weight.data = shuffle_weight_for_activation_kernel( + module.w3_w1_weight.data) + if module.bias: + # Bias should also be shuffled here + module.w3_w1_bias.data = shuffle_weight_for_activation_kernel( + module.w3_w1_bias.data) + # Step3: calculate and store final loaded weights + module.fc31_dequant.data.copy_(tmp_w3_w1_weight_scale, + non_blocking=True) + module.fc2_dequant.data.copy_(tmp_w2_weight_scale, non_blocking=True) + module.fc31_input_dequant.data.copy_(max_fc31_input_scale, + non_blocking=True) + module.fc2_input_dequant.data.copy_(max_fc2_input_scale, + non_blocking=True) + + def load_expert_weights_to_dst( + self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode, + load_expert_ids: List[int], dst_w3_w1_weights_tensor: torch.Tensor, + dst_w2_weights_tensor: torch.Tensor, + dst_w3_w1_bias_tensor: Optional[torch.Tensor], + dst_w2_bias_tensor: Optional[torch.Tensor]): + FusedMoEMethodBase.load_expert_weights_to_dst( + self, module, weights, weight_loading_mode, load_expert_ids, + dst_w3_w1_weights_tensor, dst_w2_weights_tensor, + dst_w3_w1_bias_tensor, dst_w2_bias_tensor) + module.w3_w1_weight.data = maybe_update_stride(module.w3_w1_weight.data) + module.w2_weight.data = maybe_update_stride(module.w2_weight.data) + + def apply(self, module: torch.nn.Module, x: torch.Tensor, + router_logits: torch.Tensor) -> torch.Tensor: + # Fetch all the data needed for the Triton kernel + hidden_states, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + x, module.fc31_input_dequant) + hidden_states_scale = module.fc31_input_dequant + expert_logits = router_logits + gemm1_weights = module.w3_w1_weight + gemm1_scales = module.fc31_dequant + gemm2_weights = module.w2_weight + gemm2_scales = module.fc2_dequant + top_k = module.routing_method.experts_per_token + + # hidden_states: (num_tokens, hidden_dim) torch.float8_e4m3fn + # hidden_states_scale: (,) torch.float32 + # expert_logits: (num_tokens, num_experts) torch.bfloat16 + # gemm1_weights: (num_experts, hidden_dim, intermediate_dim * 2) torch.float8_e4m3fn + # gemm1_scales: (num_experts, ) torch.float32 + # gemm2_weights: (num_experts, intermediate_dim, hidden_dim) torch.float8_e4m3fn + # gemm2_scales: (num_experts, ) torch.float32 + + # Step 1: Routing + num_experts = expert_logits.shape[1] + if num_experts > 1: + rdata, gather_indx, scatter_indx = TritonEPRouter()( + expert_logits, + top_k, + ep=module.ep_size, + node_idx=module.ep_rank) + else: + rdata, gather_indx, scatter_indx = None, None, None + + # Step 2: Gemm1 + # Setup quantization context + flex_ctx_1 = FlexCtx( + lhs_data=InFlexData(scale=hidden_states_scale), + rhs_data=InFlexData(scale=gemm1_scales), + ) + pc1 = PrecisionConfig(flex_ctx=flex_ctx_1, + allow_tf32=False, + out_dtype=module.dtype) + + # Call the Triton gemm kernel, which also does permutation and activation + alpha = module.swiglu_alpha or 1.0 + beta = module.swiglu_beta or 0.0 + if beta == 1.0: + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, + ("alpha", "limit")), (alpha, module.swiglu_limit), 2) + act_out = matmul_ogs(hidden_states, + gemm1_weights, + module.w3_w1_bias if module.bias else None, + rdata, + gather_indx=gather_indx, + precision_config=pc1, + fused_activation=act) + else: + act_out = matmul_ogs(hidden_states, + gemm1_weights, + module.w3_w1_bias if module.bias else None, + rdata, + gather_indx=gather_indx, + precision_config=pc1) + act_out = swiglu_torch(act_out, alpha, beta, module.swiglu_limit) + + # Quantize the activation output manually since the Triton activation kernel doesn't support bf16 in fp8 out + act_out, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + act_out, module.fc2_input_dequant) + + # Step 3: Gemm2 + # Setup quantization context + flex_ctx_2 = FlexCtx( + lhs_data=InFlexData(scale=module.fc2_input_dequant), + rhs_data=InFlexData(scale=gemm2_scales), + ) + pc2 = PrecisionConfig(flex_ctx=flex_ctx_2, + allow_tf32=False, + out_dtype=module.dtype) + + # Call the Triton kernel, which also does finalization + gemm2_output = matmul_ogs(act_out, + gemm2_weights, + module.w2_bias if module.bias else None, + rdata, + scatter_indx=scatter_indx, + precision_config=pc2, + gammas=rdata.gate_scal if rdata else None) + return gemm2_output + + +class TritonMXFP4FusedMoEQuantScales(NamedTuple): + fc1_dequant: torch.Tensor + fc2_dequant: torch.Tensor + fc1_input_dequant: torch.Tensor + fc2_input_dequant: torch.Tensor + + +def swizzle_weight_and_scale(w: torch.Tensor, w_scale: torch.Tensor): + # (num_experts, in_dim//2, out_dim) + w_shape = w.shape + # (num_experts, in_dim//32, out_dim) + w_scale_shape = w_scale.shape + assert w_shape[0] == w_scale_shape[0] + assert w_shape[1] * 2 == w_scale_shape[1] * 32 + assert w_shape[2] == w_scale_shape[2] + w = maybe_update_stride(w) + #num_warps = 4 if batch <= 512 else 8 + num_warps = int(os.getenv("TRITON_MOE_MXFP4_NUM_WARPS", 4)) + assert num_warps in [4, 8], \ + f"TRITON_MOE_MXFP4_NUM_WARPS should be 4 or 8, got {num_warps}" + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1) + scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=1, num_warps=num_warps) + # swizzling path is broken for H20 + if torch.cuda.get_device_name() == "NVIDIA H20": + from triton_kernels.tensor_details.layout_details.strided import \ + StridedLayout + value_layout = StridedLayout + value_layout_opts = dict() + scale_layout = StridedLayout + scale_layout_opts = dict() + + opt = {"value_layout": value_layout, "value_layout_opts": value_layout_opts, \ + "scale_layout": scale_layout, "scale_layout_opts": scale_layout_opts} + + # w, w_scale = downcast_to_mxfp(tensor.to(torch.bfloat16), torch.uint8, axis=1) + w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], + **opt["value_layout_opts"]) + w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], + **opt["scale_layout_opts"]) + return w, w_scale + + +# We inherit from TritonUnquantizedFusedMoEMethod to reuse the weight preprocessing logic +class TritonMXFP4FusedMoEMethod(TritonUnquantizedFusedMoEMethod): + + def __init__(self, activation_dtype): + super().__init__(shuffle_weight=True) + assert activation_dtype in [torch.float8_e4m3fn, torch.bfloat16], \ + f"TritonMXFP4FusedMoEMethod only supports float8_e4m3fn or bfloat16 activation, got {activation_dtype}" + self.activation_dtype = activation_dtype + self.in_dim_padding_multiple = 128 + self.out_dim_padding_multiple = 256 + + def create_weights(self, module: torch.nn.Module): + weight_dtype = torch.uint8 + + # The Triton kernel accepts the w3_w1_weight in (num_experts, hidden_dim, intermediate_dim * 2) format + w3_w1_weight_shape = ( + module.expert_size_per_partition, + module.hidden_size // 2, # Two mxfp4 packed to a byte + module.intermediate_size_per_partition * 2, + ) + + # Full scale is loaded at the beginning, later we will slice properly for TP + w3_w1_scale_shape = ( + w3_w1_weight_shape[0], + ceil_div(module.hidden_size, 32), # block size of 32 for mxfp4 + module.intermediate_size * 2, + ) + + # The Triton kernel accepts the w2_weight in (num_experts, intermediate_dim, hidden_dim) format + w2_weight_shape = ( + module.expert_size_per_partition, + module.intermediate_size_per_partition // + 2, # Two mxfp4 packed to a byte, + module.hidden_size, + ) + + # Full scale is loaded at the beginning, later we will slice properly for TP + w2_scale_shape = ( + w2_weight_shape[0], + ceil_div(module.intermediate_size, + 32), # block size of 32 for mxfp4 + w2_weight_shape[2], + ) + + FusedMoEMethodBase.create_weights(self, + module, + weight_dtype, + w3_w1_weight_shape, + w2_weight_shape, + bias_dtype=torch.float32) + + fc31_dequant = nn.Parameter( + torch.empty(w3_w1_scale_shape, dtype=torch.uint8), # mxfp8 scale + requires_grad=False) + module.register_parameter("fc31_dequant", fc31_dequant) + + fc2_dequant = nn.Parameter( + torch.empty(w2_scale_shape, dtype=torch.uint8), # mxfp8 scale + requires_grad=False) + module.register_parameter("fc2_dequant", fc2_dequant) + + if self.activation_dtype == torch.float8_e4m3fn: + fc31_input_dequant = nn.Parameter(torch.tensor(1., + dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc31_input_dequant", fc31_input_dequant) + + fc2_input_dequant = nn.Parameter(torch.tensor(1., + dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc2_input_dequant", fc2_input_dequant) + + self.setup_quant_scales(module) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = TritonMXFP4FusedMoEQuantScales( + fc1_dequant=module.fc31_dequant, + fc2_dequant=module.fc2_dequant, + fc1_input_dequant=getattr( + module, 'fc31_input_dequant', + None), # activation scale exists only for float8_e4m3fn + fc2_input_dequant=getattr( + module, 'fc2_input_dequant', + None), # activation scale exists only for float8_e4m3fn + ) + + def load_expert_weights_to_dst( + self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode, + load_expert_ids: List[int], dst_w3_w1_weights_tensor: torch.Tensor, + dst_w2_weights_tensor: torch.Tensor, + dst_w3_w1_bias_tensor: Optional[torch.Tensor], + dst_w2_bias_tensor: Optional[torch.Tensor]): + # dynamic quant scales for weights + self.w3_scales = {} + self.w1_scales = {} + self.w2_scales = {} + # Multithread weight load is superseded by prefetch_files() in model_engine.py + # Also, threading adds overhead in order to protect shuffle index cache with critical section. + for local_slot_id, expert_id in enumerate(load_expert_ids): + # expert_idx is the local slot index of current rank + expert_idx = local_slot_id + + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight = weights[f"{expert_id}.w1.weight"] + w3_weight = weights[f"{expert_id}.w3.weight"] + w2_weight = weights[f"{expert_id}.w2.weight"] + if module.bias: + w1_bias = weights[f"{expert_id}.w1.bias"] + w3_bias = weights[f"{expert_id}.w3.bias"] + w2_bias = weights[f"{expert_id}.w2.bias"] + elif weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_w3_weight = weights["gate_up_proj"][expert_id].transpose( + 0, 1) + w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0) + w2_weight = weights["down_proj"][expert_id].transpose( + 0, 1).contiguous() + if module.bias: + w1_w3_bias = weights["gate_up_proj.bias"][expert_id] + w1_bias, w3_bias = w1_w3_bias.chunk(2, dim=0) + w2_bias = weights["down_proj.bias"][expert_id] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {weight_loading_mode}" + ) + + w3_scale, w1_scale = self.load_expert_w3_w1_weight( + module, w1_weight, w3_weight, + dst_w3_w1_weights_tensor[expert_idx]) + + w2_scale = self.load_expert_w2_weight( + module, w2_weight, dst_w2_weights_tensor[expert_idx]) + if w3_scale is not None: + self.w3_scales[expert_id] = w3_scale + if w1_scale is not None: + self.w1_scales[expert_id] = w1_scale + if w2_scale is not None: + self.w2_scales[expert_id] = w2_scale + + if module.bias: + self.load_expert_w3_w1_weight( + module, + w1_bias, + w3_bias, + dst_w3_w1_bias_tensor.data[expert_idx], + is_bias=True) + + self.load_expert_w2_weight(module, + w2_bias, + dst_w2_bias_tensor.data[expert_idx], + is_bias=True) + + def _permute_mxfp4_quantize(self, tensor): + tensor = tensor.transpose(-2, -1).contiguous() + tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor, + torch.uint8, + axis=-2) + return tensor_fp4, tensor_scales + + def load_expert_w3_w1_weight(self, + module: torch.nn.Module, + w1_weight: torch.Tensor, + w3_weight: torch.Tensor, + dst_w3_w1_weight: torch.Tensor, + is_bias: bool = False): + """ + Load w1 and w3 weights for each expert. + Override this method if you need to preprocess the weights differently. + """ + device = dst_w3_w1_weight.device + assert device.type == "cuda" + w1_weight_shard = load_weight_shard(w1_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w3_weight_shard = load_weight_shard(w3_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + + if not is_bias and w3_weight_shard.dtype in (torch.bfloat16, + torch.float16, + torch.float32): + # [N, K] -> [K, N] + w3_weight_shard, w3_scales = self._permute_mxfp4_quantize( + w3_weight_shard) + w1_weight_shard, w1_scales = self._permute_mxfp4_quantize( + w1_weight_shard) + cat_dim = 1 + else: + # [N, K] + w3_scales = None + w1_scales = None + cat_dim = 0 + + w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], + dim=cat_dim) + + # This function is shared by weights and biases, we only do transpose for weights + if not is_bias and cat_dim == 0: + # Transpose the weights to match the expected format for the Triton gemm kernel + w31_weight_shard = w31_weight_shard.transpose(0, 1).contiguous() + else: + # We use .to here since for Triton the bias is always in float32 and a conversion is needed. + w31_weight_shard = w31_weight_shard.to(dst_w3_w1_weight.dtype) + + if self.shuffle_weight: + w31_weight_shard = shuffle_weight_for_activation_kernel( + w31_weight_shard) + + dst_w3_w1_weight.copy_(w31_weight_shard, non_blocking=True) + return (w3_scales, w1_scales) + + def load_expert_w2_weight(self, + module: torch.nn.Module, + w2_weight: torch.Tensor, + dst_w2_weight: torch.Tensor, + is_bias: bool = False): + """ + Load w2 weight for each expert. + Override this method if you need to preprocess the weights differently. + """ + device = dst_w2_weight.device + assert device.type == "cuda" + w2_weight_shard = load_weight_shard(w2_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + + w2_scales = None + + if is_bias: + # We use .to here since for Triton the bias is always in float32 and a conversion is needed. + w2_weight_shard = w2_weight_shard.to(dst_w2_weight.dtype) + assert w2_weight_shard.dim() == 1 + # Handle TP contribution of bias + w2_weight_shard /= module.tp_size + else: + if w2_weight_shard.dtype in (torch.bfloat16, torch.float16, + torch.float32): + # [N, K] -> [K, N] + w2_weight_shard, w2_scales = self._permute_mxfp4_quantize( + w2_weight_shard) + else: + # Transpose the weights to match the expected format for the Triton gemm kernel + # [N, K] -> [K, N] + w2_weight_shard = w2_weight_shard.transpose(0, 1).contiguous() + + dst_w2_weight.copy_(w2_weight_shard, non_blocking=True) + + return w2_scales + + def _load_expert_w3_w1_weight_scale_mxfp4( + self, w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight_scale: torch.Tensor, transpose_scales: bool): + if transpose_scales: + # (intermediate_dim * 2, hidden_dim / 32) + combined_scale = torch.cat([w3_weight_scale, w1_weight_scale], + dim=0) + # (hidden_dim / 32, intermediate_dim * 2) + combined_scale = combined_scale.transpose(0, 1) + else: + # (hidden_dim / 32, intermediate_dim * 2) + combined_scale = torch.cat([w3_weight_scale, w1_weight_scale], + dim=1) + + dst_w3_w1_weight_scale.copy_(combined_scale, non_blocking=True) + + def _load_expert_w2_weight_scale_mxfp4(self, w2_weight_scale, + dst_w2_weight_scale: torch.Tensor, + transpose_scales: bool): + if transpose_scales: + w2_weight_scale = w2_weight_scale.transpose( + 0, 1) # (intermediate_dim / 32, hidden_dim) + dst_w2_weight_scale.copy_(w2_weight_scale, non_blocking=True) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # Step1: Load input scales. + if self.activation_dtype == torch.float8_e4m3fn: + try: + max_fc31_input_scale, max_fc2_input_scale = load_activation_scales_fp8_qdq( + module, weights) + except KeyError: + # We will use dynamic quantization + max_fc31_input_scale = None + max_fc2_input_scale = None + + # Step2: Load weight scales + device = module.w3_w1_weight.device + tmp_w3_w1_weight_scale = torch.empty(module.fc31_dequant.shape, + dtype=torch.uint8, + device=device) + tmp_w2_weight_scale = torch.empty(module.fc2_dequant.shape, + dtype=torch.uint8, + device=device) + for local_slot_id, expert_id in enumerate( + module.initial_local_expert_ids): + try: + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] + w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] + w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] + need_to_transpose_scales = True + elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + # Reverse-engineered from the openai modeling class + combined_weight_scale = weights[ + "gate_up_proj_weight_scale"][expert_id] + out_dim = combined_weight_scale.shape[-1] + w1_weight_scale = combined_weight_scale[..., :out_dim // 2] + w3_weight_scale = combined_weight_scale[..., out_dim // 2:] + w2_weight_scale = weights[f"down_proj_weight_scale"][ + expert_id] + need_to_transpose_scales = False + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" + ) + except KeyError: + # We will use dynamic quantization + w1_weight_scale = self.w1_scales[expert_id] + w3_weight_scale = self.w3_scales[expert_id] + w2_weight_scale = self.w2_scales[expert_id] + need_to_transpose_scales = False + + expert_idx = local_slot_id + + self._load_expert_w3_w1_weight_scale_mxfp4( + w1_weight_scale, w3_weight_scale, + tmp_w3_w1_weight_scale[expert_idx], need_to_transpose_scales) + + self._load_expert_w2_weight_scale_mxfp4( + w2_weight_scale, tmp_w2_weight_scale[expert_idx], + need_to_transpose_scales) + + self.w1_scales.clear() + self.w3_scales.clear() + self.w2_scales.clear() + + # Scales need to be shuffled as well + tmp_w3_w1_weight_scale = shuffle_weight_for_activation_kernel( + tmp_w3_w1_weight_scale) + + # For Hopper style swizzle, we need to pad the out dim to multiple of 256 otherwise it sometimes produces nan + def _maybe_pad_weight_and_scale(weight, + scale=None, + in_dim_padding_offset=0): + # Both weight and bias are handled here + assert weight.dim() in [2, 3], "Weight should be 2D or 3D tensor" + # out_dim padding is only required for Hopper + if torch.cuda.get_device_capability()[0] == 9: + out_dim = weight.shape[-1] + assert scale is None or scale.shape[ + -1] == out_dim, "Out dim of weight and scale should match" + pad_size = (self.out_dim_padding_multiple - + out_dim % self.out_dim_padding_multiple + ) % self.out_dim_padding_multiple + weight = F.pad( + weight, + (0, pad_size)) # Pad the last dimension on right side + if scale is not None: + scale = F.pad(scale, (0, pad_size)) + # in_dim padding is always required when we have TP because of mxfp4 scale block size + # We only do in_dim padding for weights but not for bias + if weight.dim() == 3: + in_dim = weight.shape[ + -2] * 2 # mxfp4 packs two values into one byte + assert scale is None or scale.shape[-2] == ceil_div( + in_dim, 32), "In dim of weight and scale should match" + pad_size = (self.in_dim_padding_multiple - + in_dim % self.in_dim_padding_multiple + ) % self.in_dim_padding_multiple + assert pad_size % 2 == 0 + pad_size //= 2 # pad_size is in mxfp4 units + assert in_dim_padding_offset % 2 == 0 + in_dim_padding_offset //= 2 + assert in_dim_padding_offset <= pad_size, "TP offset larger than pad size" + weight = F.pad(weight, (0, 0, in_dim_padding_offset, + pad_size - in_dim_padding_offset)) + assert scale is not None # Bias won't enter this branch + new_in_dim = weight.shape[-2] * 2 + assert new_in_dim % 32 == 0 + new_scale_in_dim = new_in_dim // 32 + scale_pad_size = new_scale_in_dim - scale.shape[-2] + assert scale_pad_size >= 0 + scale = F.pad(scale, (0, 0, 0, scale_pad_size)) + + return (weight, scale) if scale is not None else weight + + # Handle w3_w1_weight + + # Slice scales for TP + tp_slice_start = module.intermediate_size_per_partition * module.tp_rank + tp_slice_end = tp_slice_start + module.intermediate_size_per_partition + #(num_experts, in_dim / 32, out_dim) + assert tmp_w3_w1_weight_scale.dim() == 3 + assert tmp_w3_w1_weight_scale.shape[-1] == module.intermediate_size * 2 + # The scale is already shuffled + tmp_w3_w1_weight_scale = tmp_w3_w1_weight_scale[:, :, tp_slice_start * + 2:tp_slice_end * 2] + + tmp_w3_w1_weight, tmp_w3_w1_weight_scale = _maybe_pad_weight_and_scale( + module.w3_w1_weight, tmp_w3_w1_weight_scale) + + module._parameters.pop('w3_w1_weight', None) + module._parameters.pop('fc31_dequant', None) + torch.cuda.empty_cache() + + tmp_w3_w1_weight, tmp_w3_w1_weight_scale = swizzle_weight_and_scale( + tmp_w3_w1_weight, tmp_w3_w1_weight_scale) + + module.w3_w1_weight = tmp_w3_w1_weight + module.fc31_dequant = tmp_w3_w1_weight_scale + + # Handle w2_weight + + # Slice scales for TP + # TP might make the weight start from half of the mxfp4 32 block + # For example, if we start from index 20, there are 12 elements in the first block instead of 32 + # We need to pad 20 elements to the first block + self.w2_tp_offset = tp_slice_start % 32 + assert tmp_w2_weight_scale.dim() == 3 + # assert tmp_w2_weight_scale.shape[-2] * 32 == module.intermediate_size + # We skip this assert to allow intermidiate_size not divisible by 32, this is used in the unit test to test TP shapes in a single gpu + scale_slice_start = tp_slice_start // 32 + scale_slice_end = (tp_slice_end - 1) // 32 + 1 + tmp_w2_weight_scale = tmp_w2_weight_scale[:, scale_slice_start: + scale_slice_end, :] + + tmp_w2_weight, tmp_w2_weight_scale = _maybe_pad_weight_and_scale( + module.w2_weight, tmp_w2_weight_scale, self.w2_tp_offset) + + module._parameters.pop('w2_weight', None) + module._parameters.pop('fc2_dequant', None) + torch.cuda.empty_cache() + + tmp_w2_weight, tmp_w2_weight_scale = swizzle_weight_and_scale( + tmp_w2_weight, tmp_w2_weight_scale) + + module.w2_weight = tmp_w2_weight + module.fc2_dequant = tmp_w2_weight_scale + + # Bias needs to be padded as well. + if module.bias: + module.w3_w1_bias.data = _maybe_pad_weight_and_scale( + module.w3_w1_bias.data) + module.w2_bias.data = _maybe_pad_weight_and_scale( + module.w2_bias.data) + + if self.activation_dtype == torch.float8_e4m3fn: + if max_fc31_input_scale is None or max_fc2_input_scale is None: + module.fc31_input_dequant = None + module.fc2_input_dequant = None + else: + module.fc31_input_dequant.data.copy_(max_fc31_input_scale, + non_blocking=True) + module.fc2_input_dequant.data.copy_(max_fc2_input_scale, + non_blocking=True) + + def apply(self, module: torch.nn.Module, x: torch.Tensor, + router_logits: torch.Tensor) -> torch.Tensor: + # Fetch all the data needed for the Triton kernel + if self.activation_dtype == torch.float8_e4m3fn: + if module.fc31_input_dequant is None: + hidden_states, hidden_states_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( + x) + else: + hidden_states, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + x, module.fc31_input_dequant) + hidden_states_scale = module.fc31_input_dequant + else: + hidden_states = x + expert_logits = router_logits + gemm1_weights = module.w3_w1_weight + gemm1_scales = module.fc31_dequant + gemm2_weights = module.w2_weight + gemm2_scales = module.fc2_dequant + top_k = module.routing_method.experts_per_token + + # hidden_states: (num_tokens, hidden_dim) torch.float8_e4m3fn + # hidden_states_scale: (,) torch.float32 + # expert_logits: (num_tokens, num_experts) torch.bfloat16 + # gemm1_weights: (num_experts, hidden_dim / 2, intermediate_dim * 2) torch.uint8 + # gemm1_scales: (num_experts, hidden_dim / 32, intermediate_dim * 2) torch.uint8 + # gemm2_weights: (num_experts, intermediate_dim / 2, hidden_dim) torch.uint8 + # gemm2_scales: (num_experts, intermediate_dim / 32, hidden_dim) torch.float32 + + # Step 1: Routing + num_experts = expert_logits.shape[1] + if num_experts > 1: + rdata, gather_indx, scatter_indx = TritonEPRouter()( + expert_logits, + top_k, + ep=module.ep_size, + node_idx=module.ep_rank) + else: + rdata, gather_indx, scatter_indx = None, None, None + + # Step 2: Gemm1 + # Setup quantization context + def _maybe_pad_activation(hidden_states, in_dim_padding_offset): + assert hidden_states.dim() == 2, "Hidden states should be 2D tensor" + in_dim = hidden_states.shape[-1] + pad_size_in = (self.in_dim_padding_multiple - + in_dim % self.in_dim_padding_multiple + ) % self.in_dim_padding_multiple + assert in_dim_padding_offset <= pad_size_in + padding = (in_dim_padding_offset, + pad_size_in - in_dim_padding_offset) + hidden_states = F.pad(hidden_states, padding) + return hidden_states + + if self.activation_dtype == torch.float8_e4m3fn: + flex_ctx_1 = FlexCtx( + lhs_data=InFlexData(scale=hidden_states_scale), ) + else: + flex_ctx_1 = FlexCtx() + pc1 = PrecisionConfig(weight_scale=gemm1_scales, + flex_ctx=flex_ctx_1, + allow_tf32=False, + out_dtype=module.dtype) + + # Call the Triton gemm kernel, which also does permutation and activation + alpha = module.swiglu_alpha or 1.0 + beta = module.swiglu_beta or 0.0 + hidden_states = _maybe_pad_activation(hidden_states, 0) + if beta == 1.0: + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, + ("alpha", "limit")), (alpha, module.swiglu_limit), 2) + + act_out = matmul_ogs(hidden_states, + gemm1_weights, + module.w3_w1_bias if module.bias else None, + rdata, + gather_indx=gather_indx, + precision_config=pc1, + fused_activation=act) + else: + act_out = matmul_ogs(hidden_states, + gemm1_weights, + module.w3_w1_bias if module.bias else None, + rdata, + gather_indx=gather_indx, + precision_config=pc1) + act_out = swiglu_torch(act_out, alpha, beta, module.swiglu_limit) + + def _maybe_remove_padding(gemm_output, expected_size): + assert gemm_output.dim() == 2 + if gemm_output.shape[-1] != expected_size: + assert gemm_output.shape[ + -1] % 128 == 0, "The padding is not done correctly" + gemm_output = gemm_output[:, :expected_size] + return gemm_output + + act_out = _maybe_remove_padding( + act_out, module.intermediate_size_per_partition).contiguous() + + if self.activation_dtype == torch.float8_e4m3fn: + # Quantize the activation output manually since the Triton activation kernel doesn't support bf16 in fp8 out + if module.fc2_input_dequant is None: + act_out, act_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( + act_out) + else: + act_out, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + act_out, module.fc2_input_dequant) + act_scale = module.fc2_input_dequant + + # Step 3: Gemm2 + # Setup quantization context + if self.activation_dtype == torch.float8_e4m3fn: + flex_ctx_2 = FlexCtx(lhs_data=InFlexData(scale=act_scale), ) + else: + flex_ctx_2 = FlexCtx() + pc2 = PrecisionConfig(weight_scale=gemm2_scales, + flex_ctx=flex_ctx_2, + allow_tf32=False, + out_dtype=module.dtype) + + # Call the Triton kernel, which also does finalization + act_out = _maybe_pad_activation(act_out, self.w2_tp_offset) + gemm2_output = matmul_ogs(act_out, + gemm2_weights, + module.w2_bias if module.bias else None, + rdata, + scatter_indx=scatter_indx, + precision_config=pc2, + gammas=rdata.gate_scal if rdata else None) + gemm2_output = _maybe_remove_padding(gemm2_output, module.hidden_size) + + return gemm2_output + + +class TritonFusedMoE(MoE): + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. + VANILLA, + bias: bool = False, + layer_idx: Optional[int] = None, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, + ): + super().__init__( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + weight_loading_mode=weight_loading_mode, + ) + if not IS_TRITON_KERNELS_AVAILABLE: + raise ImportError("Triton kernels are not available.") + if torch.cuda.get_device_capability()[0] != 9 and self.ep_size > 1: + raise NotImplementedError( + "TritonFusedMoE is only supported on Hopper with EP size > 1.") + + assert isinstance(self.routing_method, RenormalizeMoeRoutingMethod), \ + "routing_method must be an instance of RenormalizeMoeRoutingMethod for TritonFusedMoE" + assert not self.smart_router, "Smart router is not supported in TritonFusedMoE." + + self.num_slots = self.num_experts + self.expert_size_per_partition = self.num_experts // self.ep_size + self.initial_global_assignments = [ + (ep_rank * self.num_experts // self.ep_size + local_slot_id) % + self.num_experts for ep_rank in range(self.ep_size) + for local_slot_id in range(self.expert_size_per_partition) + ] + self.slot_start = self.ep_rank * self.expert_size_per_partition + self.slot_end = self.slot_start + self.expert_size_per_partition + self.initial_local_expert_ids = self.initial_global_assignments[ + self.slot_start:self.slot_end] + assert len( + self.initial_local_expert_ids) == self.expert_size_per_partition + + self.bias = bias + + def _maybe_squeeze_act_param(p): + if p is None or isinstance(p, (int, float)): + return p + assert isinstance(p, torch.Tensor) + assert p.dtype == torch.float32 + assert p.shape == (self.expert_size_per_partition, ), p.shape + assert torch.all( + p == p[0] + ), "All experts must have the same swiglu alpha/beta for Triton kernel" + p = p[0].item() + return p + + self.swiglu_alpha = _maybe_squeeze_act_param(swiglu_alpha) + self.swiglu_beta = _maybe_squeeze_act_param(swiglu_beta) + self.swiglu_limit = _maybe_squeeze_act_param(swiglu_limit) + + self._weights_created = False + if not model_config.skip_create_weights_in_init: + self.create_weights() + + def _get_quant_method(self): + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True): + if self.quant_config.layer_quant_mode.has_fp8_qdq(): + return TritonFP8QDQFusedMoEMethod() + elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8(): + return TritonMXFP4FusedMoEMethod( + activation_dtype=torch.float8_e4m3fn) + elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4(): + assert self.dtype in ( + torch.bfloat16, torch.float16 + ), "Only bfloat16 and float16 are supported for w4a16_mxfp4" + return TritonMXFP4FusedMoEMethod(activation_dtype=self.dtype) + else: + return TritonUnquantizedFusedMoEMethod() + + def create_weights(self): + if self._weights_created: + return + + self.quant_method = self._get_quant_method() + self.quant_method.create_weights(self) + + self._weights_created = True + + def forward( + self, + x: torch.Tensor, + router_logits: torch.Tensor, + do_finalize: bool = True, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> torch.Tensor: + assert do_finalize, "TritonFusedMoE does not support do_finalize=False" + assert use_dp_padding is None or not use_dp_padding, \ + "TritonFusedMoE does not support use_dp_padding=True" + + hidden_states = self.quant_method.apply(self, x, router_logits) + + final_hidden_states = self.reducescatter_or_allreduce( + hidden_states, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding) + + return final_hidden_states + + def load_weights(self, weights: List[Dict]): + assert self._weights_created + assert len(weights) == 1 + weights = weights[0] + + self.quant_method.load_weights(self, weights, self.weight_loading_mode) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 94e082a667..a74d8f2e73 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -1,13 +1,16 @@ from typing import Dict, List, Optional, Union import torch +from torch import nn -from ...distributed.ops import reducescatter from ...model_config import ModelConfig -from ...utils import Fp4QuantizedTensor +from ...utils import Fp4QuantizedTensor, next_positive_power_of_2 from .interface import MoE, MoEWeightLoadingMode from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, - NVFP4TRTLLMGenFusedMoEMethod) + NVFP4TRTLLMGenFusedMoEMethod, + W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, + W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, + W4A16MXFP4TRTLLMGenFusedMoEMethod) from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod @@ -20,14 +23,13 @@ class TRTLLMGenFusedMoE(MoE): top_k (int): Number of top experts to select for each input token. hidden_size (int): Size of the hidden state. intermediate_size (int): Size of the intermediate state. - aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. MoE torch custom op: Only support min-latency mode now (SM100 Blackwell only). - Quant: fp8 block scales quant and nvfp4 quant + Quant: fp8 block scales quant and nvfp4 quant and w4a16_mxfp4 quant FusedMoE Op: routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalize MoeRoute FusedMoE module: @@ -56,6 +58,10 @@ class TRTLLMGenFusedMoE(MoE): weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. VANILLA, layer_idx: Optional[int] = None, + bias: bool = False, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, ): super().__init__( routing_method=routing_method, @@ -66,6 +72,10 @@ class TRTLLMGenFusedMoE(MoE): reduce_results=reduce_results, model_config=model_config, weight_loading_mode=weight_loading_mode, + bias=bias, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit, ) assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE." @@ -89,7 +99,33 @@ class TRTLLMGenFusedMoE(MoE): self.create_weights() def _check_configs(self): - assert self.has_deepseek_fp8_block_scales or self.has_nvfp4, "TRTLLMGenFusedMoE only supports fp8_block_scaling and nvfp4 dtypes." + assert self.has_deepseek_fp8_block_scales \ + or self.has_nvfp4 or self.has_w4a16_mxfp4 \ + or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes." + + if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None: + assert self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports mxfp4 quantization with bias, swiglu_alpha, swiglu_beta and swiglu_limit." + + def _get_tile_tokens_dim(self, x: torch.Tensor): + top_k = self.routing_method.top_k + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # 1.0 means perfect expert distribution. + # > 1.0 means some experts have more tokens than the perfect distribution. + # < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // self.num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim def _get_quant_method(self): if self.quant_config is not None: @@ -97,6 +133,12 @@ class TRTLLMGenFusedMoE(MoE): return DeepSeekFP8BlockScalesFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_nvfp4(): return NVFP4TRTLLMGenFusedMoEMethod() + elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4(): + return W4A16MXFP4TRTLLMGenFusedMoEMethod() + elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8(): + return W4A8MXFP4FP8TRTLLMGenFusedMoEMethod() + elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8(): + return W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod() else: raise NotImplementedError( f"Unsupported quantization method by TRTLLMGenFusedMoE: {self.quant_config.quant_mode}" @@ -105,24 +147,6 @@ class TRTLLMGenFusedMoE(MoE): raise NotImplementedError( "TRTLLMGenFusedMoE doesn't support fp16/bf16/fp32 MoE.") - def reducescatter_or_allreduce( - self, - inputs, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, - ): - outputs = inputs - if self.parallel_size > 1: - if self.use_dp: - outputs = reducescatter( - inputs, - self.mapping, - dim=0, - sizes=None if use_dp_padding else all_rank_num_tokens) - elif self.reduce_results: - outputs = self.all_reduce(inputs) - return outputs - def create_weights(self): if self._weights_created: return @@ -133,6 +157,20 @@ class TRTLLMGenFusedMoE(MoE): self._weights_created = True self._check_configs() + # TODO: FIX this. + if (self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 + or self.has_w4a8_mxfp4_mxfp8) and not self.bias: + self.w3_w1_bias = nn.Parameter(torch.zeros( + (self.w3_w1_weight.shape[0], self.w3_w1_weight.shape[1]), + dtype=torch.float32), + requires_grad=False) + self.register_parameter("w3_w1_bias", self.w3_w1_bias) + self.w2_bias = nn.Parameter(torch.zeros( + (self.w2_weight.shape[0], self.w2_weight.shape[1]), + dtype=torch.float32), + requires_grad=False) + self.register_parameter("w2_bias", self.w2_bias) + def load_weights(self, weights: List[Dict]): assert self._weights_created @@ -235,9 +273,118 @@ class TRTLLMGenFusedMoE(MoE): return outputs else: final_hidden_states = outputs[0] + elif self.has_w4a16_mxfp4: + assert x.dtype == torch.bfloat16 + + pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] + x = torch.nn.functional.pad(x, (0, pad_size)) + intermediate_size_per_partition_padded = self.w3_w1_weight.shape[ + -2] // 2 + final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner( + router_logits, + routing_bias, + x, + self.w3_w1_weight, + self.w3_w1_weight_scale, + self.w3_w1_bias, + self.swiglu_alpha, + self.swiglu_beta, + self.swiglu_limit, + self.w2_weight, + self.w2_weight_scale, + self.w2_bias, + self.num_slots, + top_k, + n_group, + topk_group, + intermediate_size_per_partition_padded, + self. + slot_start, # local_expert_start; use ep_rank if stride!=1 + self.expert_size_per_partition, # local_expert_size + routed_scaling_factor, + self._get_tile_tokens_dim(x), + self.routing_method.routing_method_type, + 0, # act_type + ) + final_hidden_states = final_hidden_states[:, :self. + hidden_size].contiguous() + elif self.has_w4a8_mxfp4_fp8: + pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] + x = torch.nn.functional.pad(x, (0, pad_size)) + x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + x, self.fc31_input_gate_dequant[0]) + intermediate_size_per_partition_padded = self.w3_w1_weight.shape[ + -2] // 2 + + final_hidden_states = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner( + router_logits, + routing_bias, + x, + self.w3_w1_weight, + self.w3_w1_weight_scale, + self.w3_w1_bias, + self.swiglu_alpha, + self.swiglu_beta, + self.swiglu_limit, + self.w2_weight, + self.w2_weight_scale, + self.w2_bias, + self.fc31_input_dequant, # output1_scales_scalar + self.fc31_input_gate_dequant, # output1_scales_gate_scalar + self.fc2_input_dequant, # output2_scales_scalar + self.num_slots, + top_k, + n_group, + topk_group, + intermediate_size_per_partition_padded, + self. + slot_start, # local_expert_start; use ep_rank if stride!=1 + self.expert_size_per_partition, # local_expert_size + routed_scaling_factor, + self._get_tile_tokens_dim(x), + self.routing_method.routing_method_type, + 0, # act_type + ) + final_hidden_states = final_hidden_states[:, :self. + hidden_size].contiguous() + elif self.has_w4a8_mxfp4_mxfp8: + # TRTLLM-Gen uses linear SF layout for the mxfp8 input. + mxfp8_x, sf = torch.ops.trtllm.mxfp8_quantize( + x, False, alignment=self.quant_method.weight_alignment) + intermediate_size_per_partition_padded = self.w3_w1_weight.shape[ + -2] // 2 + + final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner( + router_logits, + routing_bias, + mxfp8_x, + sf, + self.w3_w1_weight, + self.w3_w1_weight_scale, + self.w3_w1_bias, + self.swiglu_alpha, + self.swiglu_beta, + self.swiglu_limit, + self.w2_weight, + self.w2_weight_scale, + self.w2_bias, + self.num_slots, + top_k, + n_group, + topk_group, + intermediate_size_per_partition_padded, + self.hidden_size, + self. + slot_start, # local_expert_start; use ep_rank if stride!=1 + self.expert_size_per_partition, # local_expert_size + routed_scaling_factor, + self._get_tile_tokens_dim(x), + self.routing_method.routing_method_type, + 0, # act_type + ) else: raise NotImplementedError( - "TRTLLMGenFusedMoE only supports fp8_block_scaling and nvfp4 dtypes." + "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4 and w4a8_mxfp4_fp8 dtypes." ) final_hidden_states = self.reducescatter_or_allreduce( diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py index 3249bac979..ed6f11993b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py @@ -81,13 +81,9 @@ class VanillaMoE(nn.ModuleList): self.num_experts) self.expert_size_per_partition = self.expert_end - self.expert_start - max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - if self.use_dp: - max_num_tokens *= model_config.mapping.world_size - self.moe_max_num_tokens = (model_config.moe_max_num_tokens - if model_config.moe_max_num_tokens - is not None else max_num_tokens) + moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens self._weights_created = False if not model_config.skip_create_weights_in_init: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 23c683d449..9fee27e6c9 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -6,12 +6,13 @@ import torch from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo from tensorrt_llm._utils import logger +from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.mapping import Mapping -from ...distributed import allgather, reducescatter +from ...distributed import AllReduce, allgather, reducescatter from ...expert_statistic import ExpertStatistic from ...model_config import ModelConfig -from ...utils import EventType, Fp4QuantizedTensor, swizzle_sf +from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor from .deep_ep_utils import buffer_pool, deep_ep_installed from .interface import MoE from .moe_load_balancer import get_moe_load_balancer @@ -43,7 +44,7 @@ class WideEPMoE(MoE): top_k (int): Number of top experts to select for each input token. hidden_size (int): Size of the hidden state. intermediate_size (int): Size of the intermediate state. - aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. @@ -64,7 +65,8 @@ class WideEPMoE(MoE): dtype: Optional[torch.dtype] = None, reduce_results: bool = False, model_config: ModelConfig = ModelConfig(), - aux_stream: Optional[torch.cuda.Stream] = None, + aux_stream_dict: Optional[Dict[AuxStreamType, + torch.cuda.Stream]] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. VANILLA, apply_router_weight_on_input: bool = False, @@ -106,7 +108,11 @@ class WideEPMoE(MoE): top_k = self.routing_method.experts_per_token self.expert_size_per_partition = moe_load_balancer_config.num_local_slots self.layer_load_balancer = moe_load_balancer.add_layer( - self.num_experts, top_k, self.expert_size_per_partition) + self.num_experts, + top_k, + self.expert_size_per_partition, + aux_stream=None if aux_stream_dict is None else + aux_stream_dict[AuxStreamType.MoeBalancer]) self.repeat_count = self.layer_load_balancer.get_repeat_count() loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments( self.layer_idx) @@ -130,6 +136,12 @@ class WideEPMoE(MoE): assert num_experts % self.ep_size == 0 self.expert_size_per_partition = num_experts // self.ep_size self.num_slots = num_experts + if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( + ): + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=AllReduceStrategy.NCCL) + else: + self.allreduce = None self.slot_start = self.ep_rank * self.expert_size_per_partition self.slot_end = self.slot_start + self.expert_size_per_partition @@ -138,14 +150,15 @@ class WideEPMoE(MoE): assert len( self.initial_local_expert_ids) == self.expert_size_per_partition - max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - max_num_tokens *= model_config.mapping.world_size - self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens + moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied - if self.moe_max_num_tokens < max_num_tokens: - self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream( - ) + if self.moe_max_num_tokens < moe_max_num_tokens: + self.aux_stream = aux_stream_dict[ + AuxStreamType. + MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream( + ) self.event_dict = { key: torch.cuda.Event() for key in [EventType.Main, EventType.MoeChunkingOverlap] @@ -170,11 +183,15 @@ class WideEPMoE(MoE): f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}", key="alltoall_method_type") self.use_postquant_alltoall = False + self.use_low_precision_combine = False if self.enable_alltoall: qm = self.quant_config.quant_mode self.use_postquant_alltoall = (os.environ.get( "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1") and qm.has_nvfp4() + self.use_low_precision_combine = (os.environ.get( + "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0") + == "1") and qm.has_nvfp4() # TODO: support alltoall without allgather for top_k % 4 != 0 self.enable_alltoall_without_allgather = ( os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER", @@ -371,12 +388,11 @@ class WideEPMoE(MoE): is_first_call, is_last_call = repeating_info - if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ) and is_first_call: - self.layer_load_balancer.wait_for_gpu_stage() + if self.layer_load_balancer and is_first_call: + self.layer_load_balancer.start_wait_gpu_stage() use_deepseek_fp8_block_scale = False - use_w4a8_group_scaling = False + use_w4_group_scaling = False weight_dtype = self.w3_w1_weight.dtype token_selected_experts, token_final_scales = self.routing_method.apply( @@ -401,40 +417,32 @@ class WideEPMoE(MoE): else: token_final_scales = None - if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ) and is_first_call: - self.layer_load_balancer.maybe_cudagraph_done_wait() - - use_allgather = not use_all_to_all - - loadbalancer_local_statistic_info = None - gathered_loadbalancer_local_statistic_info = None - token_selected_experts_for_statistic = None - if self.layer_load_balancer is None: - token_selected_slots = token_selected_experts - else: - if not self.layer_load_balancer.is_static_routing( - ) and use_all_to_all: - self.layer_load_balancer.local_statistic( + if self.layer_load_balancer: + if is_first_call: + self.layer_load_balancer.done_wait_gpu_stage() + if use_all_to_all and self.alltoall_method_type == AlltoallMethodType.MNNVL: + self.layer_load_balancer.update_local_statistic( token_selected_experts, is_first_stage=is_first_call, is_last_stage=is_last_call) + else: + self.layer_load_balancer.update_statistic_with_local_ids( + token_selected_experts, + is_first_stage=is_first_call, + is_last_stage=is_last_call, + allreduce=self.allreduce) token_selected_slots = self.layer_load_balancer.route( token_selected_experts, self.use_dp) - if not self.layer_load_balancer.is_static_routing(): - # split into two part to get possible overlap with load balancer routing - if use_all_to_all: - if is_last_call: - loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor( - ) - else: - token_selected_experts_for_statistic = token_selected_experts + else: + token_selected_slots = token_selected_experts # If load balancer is disabled, the statistics are collected from expert IDs. # If load balancer is enabled, the statistics are collected from expert slot IDs. ExpertStatistic.set_layer(self.layer_idx) ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) + use_allgather = not use_all_to_all + # If alltoall is disabled, we need also disable use_postquant_alltoall use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all @@ -456,6 +464,11 @@ class WideEPMoE(MoE): self.dummy_allreduce() token_count = x.shape[0] alltoall_info = None + if is_last_call: + loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor( + ) + else: + loadbalancer_local_statistic_info = None x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \ self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens, x, @@ -463,6 +476,11 @@ class WideEPMoE(MoE): token_final_scales, use_postquant_alltoall, loadbalancer_local_statistic_info) + if gathered_loadbalancer_local_statistic_info is not None: + gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( + (self.mapping.moe_ep_size, self.num_experts)) + self.layer_load_balancer.update_statistic_with_gathered_statistic( + gathered_loadbalancer_local_statistic_info) elif self.alltoall_method_type == AlltoallMethodType.DeepEP: if not use_postquant_alltoall: x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \ @@ -470,10 +488,6 @@ class WideEPMoE(MoE): self.expert_size_per_partition * self.mapping.moe_ep_rank) padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors( x, None, recv_topk_idx, token_final_scales) - if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ): - gathered_loadbalancer_local_statistic_info = allgather( - loadbalancer_local_statistic_info, self.mapping, dim=0) elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: if not use_postquant_alltoall: deep_ep_topk_idx = token_selected_slots @@ -503,10 +517,6 @@ class WideEPMoE(MoE): x.shape[0], 1) token_final_scales = torch.ones_like( token_selected_slots, dtype=token_final_scales.dtype) - if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ): - gathered_loadbalancer_local_statistic_info = allgather( - loadbalancer_local_statistic_info, self.mapping, dim=0) x_sf = None x_row = x.shape[0] @@ -533,13 +543,13 @@ class WideEPMoE(MoE): self.fc31_input_scale, self.scaling_vector_size, sfUseUE8M0=False, - swizzedLayout=False) + isSfSwizzledLayout=False) x_sf = x_sf.view((x_row, -1)) elif self.has_deepseek_fp8_block_scales: use_deepseek_fp8_block_scale = True elif self.has_w4afp8: - use_w4a8_group_scaling = True + use_w4_group_scaling = True weight_dtype = torch.quint4x2 else: raise ValueError( @@ -550,35 +560,17 @@ class WideEPMoE(MoE): # using allgather case. if self.enable_dummy_allreduce: self.dummy_allreduce() - x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather( + x, x_sf, token_selected_slots, token_final_scales = allgather( [ x, x_sf, token_selected_slots, token_final_scales, - token_selected_experts_for_statistic, ], self.mapping, dim=0, sizes=None if use_dp_padding else all_rank_num_tokens) x_row = x.shape[0] - # Fp4 gemm has extra scaling factor - if x_sf is not None: - x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) - - if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ): - if use_all_to_all: - if is_last_call: - gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( - (self.mapping.moe_ep_size, self.num_experts)) - self.layer_load_balancer.update_statistic( - gathered_loadbalancer_local_statistic_info) - else: - self.layer_load_balancer.statistic( - gathered_token_selected_experts_for_statistic, - is_first_stage=is_first_call, - is_last_stage=is_last_call) ep_size = self.ep_size ep_rank = self.ep_rank @@ -605,9 +597,6 @@ class WideEPMoE(MoE): x, x_sf, recv_topk_idx, token_final_scales) if x_sf is not None: x_sf = x_sf.view(x_sf_dtype) - if self.has_nvfp4: - x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2, - self.scaling_vector_size) elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: token_num = x_row hidden_size = x_col @@ -641,8 +630,6 @@ class WideEPMoE(MoE): x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]) x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1], x_sf.shape[2]) - x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2, - self.scaling_vector_size) token_selected_slots = token_selected_slots.view(x.shape[0], 1) token_final_scales = torch.ones_like( token_selected_slots, dtype=token_final_scales.dtype) @@ -662,6 +649,7 @@ class WideEPMoE(MoE): output_dtype, quant_scales=quant_scales, input_sf=x_sf, + swizzled_input_sf=False, tp_size=self.tp_size, tp_rank=self.tp_rank, ep_size=ep_size, @@ -670,16 +658,15 @@ class WideEPMoE(MoE): cluster_rank=cluster_rank, enable_alltoall=use_all_to_all, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - use_w4a8_group_scaling=use_w4a8_group_scaling, + use_w4_group_scaling=use_w4_group_scaling, min_latency_mode=False, tune_max_num_tokens=self.tune_max_num_tokens, tuner_num_tokens=tuner_num_tokens, tuner_top_k=tuner_top_k, ) - if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ) and is_last_call: - self.layer_load_balancer.set_cpu_stage() + if self.layer_load_balancer and is_last_call: + self.layer_load_balancer.start_set_cpu_stage() # Only in cutlass_min_latency_mode, the output is a list of tensors. # Otherwise, the output should be unpacked as a single tensor. @@ -701,17 +688,23 @@ class WideEPMoE(MoE): final_hidden_states = final_hidden_states.view( self.expert_size_per_partition, num_tokens_per_expert_for_fused_moe, self.hidden_size) - final_hidden_states = self.deep_ep_buffer.low_latency_combine( - final_hidden_states, deep_ep_topk_idx, deep_ep_topk_weights, - deep_ep_handle) + if self.use_low_precision_combine: + global_scales = (448 * 6) / final_hidden_states.abs().max( + dim=-1, keepdim=True).values.to(torch.float32) + final_hidden_states = self.deep_ep_buffer.low_latency_combine_fp4( + final_hidden_states, global_scales, deep_ep_topk_idx, + deep_ep_topk_weights, deep_ep_handle) + else: + final_hidden_states = self.deep_ep_buffer.low_latency_combine( + final_hidden_states, deep_ep_topk_idx, + deep_ep_topk_weights, deep_ep_handle) else: raise NotImplementedError( f"Not available alltoall method type: {self.alltoall_method_type!r}" ) - if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ) and is_last_call: - self.layer_load_balancer.maybe_cudagraph_done_set_cpu_stage() + if self.layer_load_balancer and is_last_call: + self.layer_load_balancer.done_set_cpu_stage() return final_hidden_states @@ -939,10 +932,6 @@ class WideEPMoE(MoE): self.alltoall_workspace, self.ep_rank, self.ep_size) - if self.has_nvfp4: - x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2, - self.scaling_vector_size) - return x, x_sf def alltoall_combine(self, final_hidden_states: torch.Tensor, diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index b90cdbe300..6301a84312 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -5,13 +5,18 @@ from typing import Dict, List, Optional import torch from torch import nn +from ...distributed.ops import reducescatter from ...model_config import ModelConfig from .routing import BaseMoeRoutingMethod class MoEWeightLoadingMode(Enum): + # Gate and up projection are not fused VANILLA = 0 + # Gate and up projection are fused FUSED_GATE_UP_PROJ = 1 + # Custom W4A8 weights from examples/quantization/quantize_mixed_precision_moe.py + W4A8_CUSTOM = 2 class MoE(nn.Module): @@ -23,7 +28,6 @@ class MoE(nn.Module): top_k (int): Number of top experts to select for each input token. hidden_size (int): Size of the hidden state. intermediate_size (int): Size of the intermediate state. - aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. @@ -41,6 +45,10 @@ class MoE(nn.Module): model_config: ModelConfig = ModelConfig(), weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. VANILLA, + bias: bool = False, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, ): from ...distributed import AllReduce @@ -50,9 +58,12 @@ class MoE(nn.Module): self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.weight_loading_mode = weight_loading_mode - + self.bias = bias self.dtype = dtype self.reduce_results = reduce_results + self.swiglu_alpha = swiglu_alpha + self.swiglu_beta = swiglu_beta + self.swiglu_limit = swiglu_limit # could be modified later self.quant_config = model_config.quant_config @@ -78,7 +89,8 @@ class MoE(nn.Module): self.intermediate_size_per_partition = intermediate_size // self.tp_size self.all_reduce = AllReduce(mapping=self.mapping, - strategy=model_config.allreduce_strategy) + strategy=model_config.allreduce_strategy, + dtype=self.dtype) @abstractmethod def create_weights(self): @@ -123,8 +135,47 @@ class MoE(nn.Module): return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4( ) + @property + def has_w4a8_mxfp4_fp8(self): + assert self._weights_created + return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8( + ) + + @property + def has_w4a8_mxfp4_mxfp8(self): + assert self._weights_created + return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8( + ) + + @property + def has_w4a16_mxfp4(self): + assert self._weights_created + return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a16_mxfp4( + ) + @property def enable_alltoall(self): """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter """ return False + + def reducescatter_or_allreduce( + self, + inputs, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + ): + """ + Common helper for TP and EP in subclasses of the MoE module. + """ + outputs = inputs + if self.parallel_size > 1 and not self.enable_alltoall: + if self.use_dp: + outputs = reducescatter( + inputs, + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + elif self.reduce_results: + outputs = self.all_reduce(inputs) + return outputs diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py index fff9ed6048..ff26c87687 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py @@ -9,10 +9,13 @@ from mpi4py import MPI import tensorrt_llm import tensorrt_llm.bindings.internal.runtime as _tbr -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping +from ...distributed import AllReduce +from ...utils import EventType +from ..multi_stream_utils import do_multi_stream + def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight: """ @@ -271,7 +274,8 @@ class SingleLayerMoeLoadBalancer: shared_mpi_comm: MPI.Comm, expert_count: int, updates_enabled: bool = True, - repeated_count=1): + repeated_count=1, + aux_stream: Optional[torch.cuda.Stream] = None): """ Initialize a SingleLayerMoeLoadBalancer instance. @@ -287,6 +291,7 @@ class SingleLayerMoeLoadBalancer: ) self.expert_count = expert_count self.updates_enabled = updates_enabled + self.repeated_count = repeated_count layer_id = self.single_layer_load_balancer_impl.get_layer_id() self.host_tensor_sharer = HostMoeTensorSharer( layer_id, expert_count, @@ -303,15 +308,34 @@ class SingleLayerMoeLoadBalancer: self.expert_count) self.load_expert_ids = list(range(load_expert_start, load_expert_end)) + if self.updates_enabled: + self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream( + ) + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, EventType.MoeBalancer] + } + else: + self.aux_stream = None + self.event_dict = None + self.statistic_flag_tensor = None self.local_statistic_tensor = None - - self.cudagraph_stream = None - self.cudagraph_event = None - self.repeated_count = repeated_count - - self.statistic_stream = None - self.statistic_event = None + self.func_called_count = { + name: 0 + for name in [ + "start_wait_gpu_stage", + "done_wait_gpu_stage", + "start_set_cpu_stage", + "done_set_cpu_stage", + "update_local_statistic", + "get_local_statistic_tensor", + "update_statistic_with_gathered_statistic", + "update_statistic_with_local_ids", + "update_statistic_with_global_ids", + "route", + ] + } def get_layer_idx(self): return self.single_layer_load_balancer_impl.get_layer_id() @@ -441,139 +465,94 @@ class SingleLayerMoeLoadBalancer: self.host_tensor_sharer.finalize_host_tensor_sharing( self._add_host_weight_from_tensor) - def wait_for_gpu_stage(self) -> Optional[torch.Tensor]: + def start_wait_gpu_stage(self): """ - Wait for the GPU stage to complete. - - Returns: - A tensor indicating whether the stage is enabled + Start to wait for the GPU stage to complete. """ + assert self.func_called_count["start_wait_gpu_stage"] == 0 + self.func_called_count["start_wait_gpu_stage"] += 1 if self.updates_enabled: - assert self.statistic_flag_tensor is None, \ - "Already has statistic_flag_tensor, should not wait." - if is_graph_capturing(): - self.cudagraph_event = torch.cuda.Event() - self.cudagraph_stream = torch.cuda.Stream() - current_stream_event = torch.cuda.Event() - current_stream_event.record(torch.cuda.current_stream()) - with torch.cuda.stream(self.cudagraph_stream): - current_stream_event.wait() + if do_multi_stream(): + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage( self.single_layer_load_balancer_ptr) - self.cudagraph_event.record(self.cudagraph_stream) + self.event_dict[EventType.MoeBalancer].record() else: self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage( self.single_layer_load_balancer_ptr) - return self.statistic_flag_tensor - else: - return - def maybe_cudagraph_done_wait(self): + def done_wait_gpu_stage(self): + """ + Done waiting for the GPU stage to complete. + """ + assert self.func_called_count["start_wait_gpu_stage"] == 1 + assert self.func_called_count["done_wait_gpu_stage"] == 0 + self.func_called_count["done_wait_gpu_stage"] += 1 if self.updates_enabled: - if is_graph_capturing(): - assert self.cudagraph_event is not None, "should have cudagraph_event when capturing" - assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing" - self.cudagraph_event.wait() + if do_multi_stream(): + self.event_dict[EventType.MoeBalancer].wait() - def set_cpu_stage(self): + def start_set_cpu_stage(self): """ - Set the CPU stage. + Start to set the CPU stage. """ + assert self.func_called_count["done_wait_gpu_stage"] == 1 + assert self.func_called_count["start_set_cpu_stage"] == 0 + self.func_called_count["start_set_cpu_stage"] += 1 if self.updates_enabled: - assert self.statistic_flag_tensor is not None, \ - "Doesn't have statistic_flag_tensor, should not set_cpu_stage." - self.statistic_flag_tensor = None - if is_graph_capturing(): - assert self.cudagraph_stream is not None, "Doesn't have cudagraph_stream, should not set_cpu_stage." - assert self.statistic_event is not None - assert self.statistic_stream is not None - # wait statistic update done - current_stream_event = torch.cuda.Event() - current_stream_event.record(torch.cuda.current_stream()) - with torch.cuda.stream(self.cudagraph_stream): - self.statistic_event.wait() - current_stream_event.wait() + if do_multi_stream(): + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() torch.ops.trtllm.moe_load_balance_set_cpu_stage( self.single_layer_load_balancer_ptr) - self.cudagraph_event.record(self.cudagraph_stream) - self.statistic_event = None - self.statistic_stream = None + self.event_dict[EventType.MoeBalancer].record() else: torch.ops.trtllm.moe_load_balance_set_cpu_stage( self.single_layer_load_balancer_ptr) - def maybe_cudagraph_done_set_cpu_stage(self): - if self.updates_enabled: - if is_graph_capturing(): - assert self.cudagraph_event is not None, "should have cudagraph_event when capturing" - assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing" - self.cudagraph_event.wait() - self.cudagraph_stream = None - self.cudagraph_event = None - - def statistic(self, gathered_raw_expert_ids: torch.Tensor, - is_first_stage: bool, is_last_stage: bool): + def done_set_cpu_stage(self): """ - Perform statistics on the expert IDs. + Done setting the CPU stage. + """ + assert self.func_called_count["start_set_cpu_stage"] == 1 + for name in self.func_called_count: + self.func_called_count[name] = 0 + self.statistic_flag_tensor = None + if self.updates_enabled: + if do_multi_stream(): + self.event_dict[EventType.MoeBalancer].wait() + + def update_local_statistic(self, local_raw_expert_ids: torch.Tensor, + is_first_stage: bool, is_last_stage: bool): + """ + Update local statistics of the expert IDs. Args: - gathered_raw_expert_ids: The gathered raw expert IDs from all ranks + local_raw_expert_ids: The local raw expert IDs is_first_stage: Whether this is the first stage is_last_stage: Whether this is the last stage """ + assert self.func_called_count["done_wait_gpu_stage"] == 1 + assert self.func_called_count["update_statistic_with_global_ids"] == 0 + self.func_called_count["update_local_statistic"] += 1 if self.updates_enabled: - assert isinstance(self.statistic_flag_tensor, torch.Tensor) - if is_graph_capturing(): - if is_first_stage: - self.statistic_event = torch.cuda.Event() - self.statistic_stream = torch.cuda.Stream() - current_stream_event = torch.cuda.Event() - current_stream_event.record(torch.cuda.current_stream()) - with torch.cuda.stream(self.statistic_stream): - current_stream_event.wait() - torch.ops.trtllm.moe_load_balance_statistic( - gathered_raw_expert_ids, self.statistic_flag_tensor, - self.single_layer_load_balancer_ptr, is_first_stage, - is_last_stage) - self.statistic_event.record() - else: - torch.ops.trtllm.moe_load_balance_statistic( - gathered_raw_expert_ids, self.statistic_flag_tensor, - self.single_layer_load_balancer_ptr, is_first_stage, - is_last_stage) - - def local_statistic(self, local_raw_expert_ids: torch.Tensor, - is_first_stage: bool, is_last_stage: bool): - """ - Perform local statistics on the expert IDs. - - Args: - local_raw_expert_ids: The gathered raw expert IDs from all ranks - is_first_stage: Whether this is the first stage - is_last_stage: Whether this is the last stage - """ - if self.updates_enabled: - assert isinstance(self.statistic_flag_tensor, torch.Tensor) - if is_first_stage: - assert self.local_statistic_tensor is None + if self.local_statistic_tensor is None: self.local_statistic_tensor = torch.empty( (self.expert_count, ), dtype=torch.int32, device=torch.device('cuda')) - if is_graph_capturing(): - if is_first_stage: - self.statistic_event = torch.cuda.Event() - self.statistic_stream = torch.cuda.Stream() - current_stream_event = torch.cuda.Event() - current_stream_event.record(torch.cuda.current_stream()) - with torch.cuda.stream(self.statistic_stream): - current_stream_event.wait() + if do_multi_stream(): + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() torch.ops.trtllm.moe_hierarchical_statistic_local_device( local_raw_expert_ids, self.local_statistic_tensor, self.statistic_flag_tensor, self.single_layer_load_balancer_ptr, is_first_stage, is_last_stage) - self.statistic_event.record(self.statistic_stream) else: torch.ops.trtllm.moe_hierarchical_statistic_local_device( local_raw_expert_ids, self.local_statistic_tensor, @@ -581,48 +560,119 @@ class SingleLayerMoeLoadBalancer: self.single_layer_load_balancer_ptr, is_first_stage, is_last_stage) - def get_local_statistic_tensor(self): + def get_local_statistic_tensor(self) -> Optional[torch.Tensor]: """ - Get the local statistic tensor. Should perform allreduce on it and then call update_statistic + Get the local statistic tensor. Returns: The local statistic tensor if using statistic else None """ + assert self.func_called_count["update_local_statistic"] > 0 + self.func_called_count["get_local_statistic_tensor"] += 1 if self.updates_enabled: - assert self.local_statistic_tensor is not None - if is_graph_capturing(): - assert self.statistic_event is not None - assert self.statistic_stream is not None - self.statistic_event.wait() + if do_multi_stream(): + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.MoeBalancer].record() + self.event_dict[EventType.MoeBalancer].wait() return self.local_statistic_tensor return None - def update_statistic(self, gathered_local_statistic_tensor: torch.Tensor): + def update_statistic_with_gathered_statistic( + self, gathered_local_statistic_tensor: torch.Tensor): """ - Perform update with global statistics. + Update statistics of the expert IDs, using gathered local statistic tensors. Args: gathered_local_statistic_tensor: gathered local statistics info, should have shape (world_size, self.expert_count) """ - if self.updates_enabled: - assert isinstance(self.statistic_flag_tensor, torch.Tensor) + assert self.func_called_count["get_local_statistic_tensor"] > 0 + assert self.func_called_count["update_statistic_with_local_ids"] == 0 + assert self.func_called_count["update_statistic_with_global_ids"] == 0 + self.func_called_count["update_statistic_with_gathered_statistic"] += 1 - def _update_statistic(): - global_statistic_info = torch.sum( - gathered_local_statistic_tensor, dim=0, dtype=torch.int32) + def _update_statistic(): + global_statistic_info = torch.sum(gathered_local_statistic_tensor, + dim=0, + dtype=torch.int32) + torch.ops.trtllm.moe_hierarchical_statistic_update( + global_statistic_info, self.statistic_flag_tensor, + self.single_layer_load_balancer_ptr) + + if self.updates_enabled: + if do_multi_stream(): + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() + _update_statistic() + else: + _update_statistic() + + def update_statistic_with_local_ids(self, + local_raw_expert_ids: torch.Tensor, + is_first_stage: bool, + is_last_stage: bool, + allreduce: Optional[AllReduce] = None): + """ + Update statistics of the expert IDs, using local raw expert IDs. + + Args: + local_raw_expert_ids: The local raw expert IDs + is_first_stage: Whether this is the first stage + is_last_stage: Whether this is the last stage + allreduce: The allreduce object + """ + assert self.func_called_count["done_wait_gpu_stage"] == 1 + assert self.func_called_count[ + "update_statistic_with_gathered_statistic"] == 0 + assert self.func_called_count["update_statistic_with_global_ids"] == 0 + self.func_called_count["update_statistic_with_local_ids"] += 1 + + def _update_statistic(): + if is_last_stage: + global_statistic_info = allreduce(self.local_statistic_tensor) torch.ops.trtllm.moe_hierarchical_statistic_update( global_statistic_info, self.statistic_flag_tensor, self.single_layer_load_balancer_ptr) - if is_graph_capturing(): - current_stream_event = torch.cuda.Event() - current_stream_event.record(torch.cuda.current_stream()) - with torch.cuda.stream(self.statistic_stream): - current_stream_event.wait() + if self.updates_enabled: + self.update_local_statistic(local_raw_expert_ids, is_first_stage, + is_last_stage) + if do_multi_stream(): + with torch.cuda.stream(self.aux_stream): _update_statistic() - self.statistic_event.record(self.statistic_stream) else: _update_statistic() - self.local_statistic_tensor = None + + def update_statistic_with_global_ids(self, + gathered_raw_expert_ids: torch.Tensor, + is_first_stage: bool, + is_last_stage: bool): + """ + Update statistics of the expert IDs, using gathered raw expert IDs from all ranks. + + Args: + gathered_raw_expert_ids: The gathered raw expert IDs from all ranks + is_first_stage: Whether this is the first stage + is_last_stage: Whether this is the last stage + """ + assert self.func_called_count["done_wait_gpu_stage"] == 1 + assert self.func_called_count[ + "update_statistic_with_gathered_statistic"] == 0 + assert self.func_called_count["update_statistic_with_local_ids"] == 0 + self.func_called_count["update_statistic_with_global_ids"] += 1 + if self.updates_enabled: + if do_multi_stream(): + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() + torch.ops.trtllm.moe_load_balance_statistic( + gathered_raw_expert_ids, self.statistic_flag_tensor, + self.single_layer_load_balancer_ptr, is_first_stage, + is_last_stage) + else: + torch.ops.trtllm.moe_load_balance_statistic( + gathered_raw_expert_ids, self.statistic_flag_tensor, + self.single_layer_load_balancer_ptr, is_first_stage, + is_last_stage) def route(self, token_selected_experts: torch.Tensor, @@ -637,6 +687,8 @@ class SingleLayerMoeLoadBalancer: Returns: A tensor of routed slot IDs """ + assert self.func_called_count["done_wait_gpu_stage"] == 1 + self.func_called_count["route"] += 1 return torch.ops.trtllm.moe_load_balance_routing( token_selected_experts, offset_by_ep_rank, self.single_layer_load_balancer_ptr) @@ -731,8 +783,13 @@ class MoeLoadBalancer: assert repeated_count > 0, "repeat count must be greater than 0" self.next_layer_repeated_count = repeated_count - def add_layer(self, expert_count: int, top_k: int, - slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer: + def add_layer( + self, + expert_count: int, + top_k: int, + slot_count_per_rank: int, + aux_stream: Optional[torch.cuda.Stream] = None + ) -> SingleLayerMoeLoadBalancer: """ Add a new layer to the load balancer. @@ -740,6 +797,7 @@ class MoeLoadBalancer: expert_count: The number of experts in the layer top_k: The number of experts each token selects slot_count_per_rank: The number of slots per rank + aux_stream: The auxiliary stream for overlapping Returns: A SingleLayerMoeLoadBalancer instance for the new layer @@ -756,7 +814,8 @@ class MoeLoadBalancer: self.shared_mpi_comm, expert_count, updates_enabled=updates_enabled, - repeated_count=repeat_count) + repeated_count=repeat_count, + aux_stream=aux_stream) single_layer_load_balancer.set_shared_memory_base_name( self.shared_memory_base_name) self.single_layer_load_balancers.append(single_layer_load_balancer) @@ -792,8 +851,8 @@ class MoeLoadBalancer: """ self.load_balancer_impl.set_warm_up_iter_count(iter_count) - def set_next_iter_info(self, enable_statistic: Optional[bool], - enable_update_weights: Optional[bool]): + def set_iter_info(self, enable_statistic: Optional[bool], + enable_update_weights: Optional[bool]): if enable_statistic is not None: self.enable_statistic = enable_statistic if enable_update_weights is not None: @@ -939,8 +998,8 @@ class MoeLoadBalancerIterContext: """ if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing( ): - self.moe_load_balancer.set_next_iter_info(self.enable_statistic, - self.enable_updates) + self.moe_load_balancer.set_iter_info(self.enable_statistic, + self.enable_updates) self.moe_load_balancer.start_iter() return self @@ -988,8 +1047,11 @@ def moe_load_balancer_set_repeated_for_next_layer(repeat_count: int): def moe_load_balancer_add_single_layer( - expert_count: int, top_k: int, - slot_count_per_rank: int) -> Optional[SingleLayerMoeLoadBalancer]: + expert_count: int, + top_k: int, + slot_count_per_rank: int, + aux_stream: Optional[torch.cuda.Stream] = None +) -> Optional[SingleLayerMoeLoadBalancer]: """ Add a new layer to the current active MoeLoadBalancer. @@ -997,11 +1059,13 @@ def moe_load_balancer_add_single_layer( expert_count: The number of experts in the layer top_k: The number of experts each token selects slot_count_per_rank: The number of slots per rank + aux_stream: The auxiliary stream for overlapping Returns: A SingleLayerMoeLoadBalancer instance for the new layer, or None if not in a MoeLoadBalancer context """ load_balancer = get_moe_load_balancer() if load_balancer is not None: - return load_balancer.add_layer(expert_count, top_k, slot_count_per_rank) + return load_balancer.add_layer(expert_count, top_k, slot_count_per_rank, + aux_stream) return None diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 249aadc04e..734a7240e4 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -1,11 +1,15 @@ +import math from abc import ABC, abstractmethod -from typing import Dict, List, NamedTuple, Union +from typing import Dict, List, NamedTuple, Optional, Union import torch +import torch.nn.functional as F from torch import nn -from tensorrt_llm import logger +import tensorrt_llm.logger as trtllm_logger from tensorrt_llm._utils import get_sm_version +from tensorrt_llm.quantization.functional import \ + preprocess_weights_for_mixed_gemm from tensorrt_llm.quantization.utils.fp4_utils import ( float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices) @@ -20,8 +24,10 @@ from .interface import MoEWeightLoadingMode FUSED_MOE_NVFP4_INPUT_DTYPE = torch.int64 # pack weights into int64, e.g. 16 x nvfp4 weight values FUSED_MOE_NVFP4_WEIGHT_DTYPE = torch.int64 +FUSED_MOE_MXFP4_WEIGHT_DTYPE = torch.int64 # pack weight block scales into int32, e.g. 4 x fp8 weight values FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE = torch.int32 +FUSED_MOE_MXFP4_WEIGHT_BLOCK_SCALE_DTYPE = torch.int32 class FusedMoEQuantScalesFP8(NamedTuple): @@ -59,10 +65,96 @@ class FusedMoEQuantScalesW4A8(NamedTuple): alpha_2: torch.Tensor +class FusedMoEQuantScalesINT8WoqPerChannel(NamedTuple): + fc31_weight_scale: torch.Tensor + fc2_weight_scale: torch.Tensor + + +class FusedMoEQuantScalesW4A16MXFP4(NamedTuple): + scale_1_interleaved: torch.Tensor + scale_2_interleaved: torch.Tensor + + +class FusedMoEQuantScalesW4A8MXFP4FP8(NamedTuple): + fc31_weight_block_scale: torch.Tensor + fc31_dequant_scale: torch.Tensor + fc2_input_scale: torch.Tensor + fc2_weight_block_scale: torch.Tensor + fc2_dequant_scale: torch.Tensor + + +class FusedMoEQuantScalesW4A8MXFP4MXFP8(NamedTuple): + fc31_weight_block_scale: torch.Tensor + fc31_dequant_scale: torch.Tensor + fc2_weight_block_scale: torch.Tensor + fc2_dequant_scale: torch.Tensor + + +def trtllmgen_maybe_get_cached_w3_w1_permute_indices( + dst_w3_w1_weight: torch.Tensor, + cache_permute_indices: Dict[tuple[tuple[int, int, int], str], + torch.Tensor], + epilogue_tile_m: int, + num_elts_per_sf: Union[None, int] = None) -> torch.Tensor: + key = (dst_w3_w1_weight.shape, "w31") + if key not in cache_permute_indices: + # Get permute indices and chain them together + permute0 = get_reorder_rows_for_gated_act_gemm_row_indices( + dst_w3_w1_weight) + if num_elts_per_sf is None: + permute1 = get_shuffle_matrix_a_row_indices( + dst_w3_w1_weight, epilogue_tile_m=epilogue_tile_m) + else: + permute1 = get_shuffle_matrix_sf_a_row_indices( + dst_w3_w1_weight, + epilogue_tile_m=epilogue_tile_m, + num_elts_per_sf=num_elts_per_sf) + # Memoize permute indices as recompute is **very** costly + cache_permute_indices[key] = permute0[permute1].to( + dst_w3_w1_weight.device) + permute_indices = cache_permute_indices[key] + return permute_indices + + +def trtllmgen_maybe_get_cached_w2_permute_indices( + dst_w2_weight: torch.Tensor, + cache_permute_indices: Dict[tuple[tuple[int, int, int], str], + torch.Tensor], + epilogue_tile_m: int, + num_elts_per_sf: Union[None, int] = None) -> torch.Tensor: + key = (dst_w2_weight.shape, "w2") + if key not in cache_permute_indices: + if num_elts_per_sf is None: + permute_indices = (get_shuffle_matrix_a_row_indices( + dst_w2_weight, epilogue_tile_m).to(dst_w2_weight.device)) + else: + permute_indices = (get_shuffle_matrix_sf_a_row_indices( + dst_w2_weight, + epilogue_tile_m=epilogue_tile_m, + num_elts_per_sf=num_elts_per_sf).to(dst_w2_weight.device)) + # Memoize permute indices as recompute is **very** costly + cache_permute_indices[key] = permute_indices + permute_indices = cache_permute_indices[key] + return permute_indices + + +def maybe_pad_for_mxfp4(weight: torch.Tensor, + col_alignment: int, + row_alignment: Optional[int] = None) -> torch.Tensor: + col_pad_size = (col_alignment - weight.shape[-1]) % col_alignment + if row_alignment: + row_pad_size = (row_alignment - weight.shape[-2]) % row_alignment + weight = F.pad(weight, (0, col_pad_size, 0, row_pad_size)) + else: + weight = F.pad(weight, (0, col_pad_size)) + return weight + + class FusedMoEMethodBase(ABC): """ Base class for all fused MoE methods. """ + weight_alignment: int = 1 def need_load_shared_weights(self, module): if hasattr( @@ -72,9 +164,16 @@ class FusedMoEMethodBase(ABC): return True return False - def create_weights(self, module: torch.nn.Module, weight_dtype: torch.dtype, - w3_w1_weight_shape: tuple[int, int, int], - w2_weight_shape: tuple[int, int, int]): + def create_weights( + self, + module: torch.nn.Module, + weight_dtype: torch.dtype, + w3_w1_weight_shape: tuple[int, int, int], + w2_weight_shape: tuple[int, int, int], + bias_dtype: Optional[torch.dtype] = None, + w3_w1_bias_shape: Optional[tuple[int, int]] = None, + w2_bias_shape: Optional[tuple[int, int]] = None, + ): # Fused gate_up_proj (column parallel) w3_w1_weight = nn.Parameter(torch.empty(w3_w1_weight_shape, dtype=weight_dtype), @@ -87,28 +186,61 @@ class FusedMoEMethodBase(ABC): requires_grad=False) module.register_parameter("w2_weight", w2_weight) - def load_expert_weights_to_dst(self, module: torch.nn.Module, - weights: List[Dict], - weight_loading_mode: MoEWeightLoadingMode, - load_expert_ids: List[int], - dst_w3_w1_weights_tensor: torch.Tensor, - dst_w2_weights_tensor: torch.Tensor): + # bias + if module.bias: + if w3_w1_bias_shape is None: + w3_w1_bias_shape = (module.expert_size_per_partition, + module.intermediate_size_per_partition * 2) + if w2_bias_shape is None: + w2_bias_shape = (module.expert_size_per_partition, + module.hidden_size) + bias_dtype = bias_dtype or module.dtype + w3_w1_bias = nn.Parameter(torch.empty(w3_w1_bias_shape, + dtype=bias_dtype), + requires_grad=False) + module.register_parameter("w3_w1_bias", w3_w1_bias) + + w2_bias = nn.Parameter(torch.empty(w2_bias_shape, dtype=bias_dtype), + requires_grad=False) + module.register_parameter("w2_bias", w2_bias) + else: + module.w3_w1_bias = None + module.w2_bias = None + + def load_expert_weights_to_dst( + self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode, + load_expert_ids: List[int], dst_w3_w1_weights_tensor: torch.Tensor, + dst_w2_weights_tensor: torch.Tensor, + dst_w3_w1_bias_tensor: Optional[torch.Tensor], + dst_w2_bias_tensor: Optional[torch.Tensor]): # Multithread weight load is superseded by prefetch_files() in model_engine.py # Also, threading adds overhead in order to protect shuffle index cache with critical section. for local_slot_id, expert_id in enumerate(load_expert_ids): # expert_idx is the local slot index of current rank expert_idx = local_slot_id - if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + if weight_loading_mode in [ + MoEWeightLoadingMode.VANILLA, + MoEWeightLoadingMode.W4A8_CUSTOM + ]: w1_weight = weights[f"{expert_id}.w1.weight"] w3_weight = weights[f"{expert_id}.w3.weight"] w2_weight = weights[f"{expert_id}.w2.weight"] + if module.bias: + w1_bias = weights[f"{expert_id}.w1.bias"] + w3_bias = weights[f"{expert_id}.w3.bias"] + w2_bias = weights[f"{expert_id}.w2.bias"] elif weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: w1_w3_weight = weights["gate_up_proj"][expert_id].transpose( 0, 1) w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0) w2_weight = weights["down_proj"][expert_id].transpose( 0, 1).contiguous() + if module.bias: + w1_w3_bias = weights["gate_up_proj.bias"][expert_id] + w1_bias, w3_bias = w1_w3_bias.chunk(2, dim=0) + w2_bias = weights["down_proj.bias"][expert_id] else: raise NotImplementedError( f"Unknown weight loading mode in MoE: {weight_loading_mode}" @@ -120,13 +252,23 @@ class FusedMoEMethodBase(ABC): self.load_expert_w2_weight(module, w2_weight, dst_w2_weights_tensor[expert_idx]) + if module.bias: + self.load_expert_w3_w1_weight( + module, w1_bias, w3_bias, + dst_w3_w1_bias_tensor.data[expert_idx]) + + self.load_expert_w2_weight(module, w2_bias, + dst_w2_bias_tensor.data[expert_idx]) + def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode): - self.load_expert_weights_to_dst(module, weights, weight_loading_mode, - module.initial_local_expert_ids, - module.w3_w1_weight.data, - module.w2_weight.data) + self.load_expert_weights_to_dst( + module, weights, weight_loading_mode, + module.initial_local_expert_ids, module.w3_w1_weight.data, + module.w2_weight.data, + module.w3_w1_bias.data if module.bias else None, + module.w2_bias.data if module.bias else None) self.load_quant_scales(module, weights) # Re-setup quant scales after loading weights as the tensors may have been modified. @@ -145,17 +287,33 @@ class FusedMoEMethodBase(ABC): module.w2_weight.data.shape[1:], dtype=module.w2_weight.data.dtype, device='cpu') - self.load_expert_weights_to_dst(module, weights, - weight_loading_mode, - local_shared_load_expert_ids, - local_shared_w3_w1_tensors, - local_shared_w2_tensors) - module.register_all_parameter_slot_and_to_fix_weight_fns({ - 'w3_w1_weight': - local_shared_w3_w1_tensors, - 'w2_weight': - local_shared_w2_tensors - }) + if module.bias: + local_shared_w3_w1_bias_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w3_w1_bias.data.shape[1:], + dtype=module.w3_w1_bias.data.dtype, + device='cpu') + local_shared_w2_bias_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w2_bias.data.shape[1:], + dtype=module.w2_bias.data.dtype, + device='cpu') + self.load_expert_weights_to_dst( + module, weights, weight_loading_mode, + local_shared_load_expert_ids, local_shared_w3_w1_tensors, + local_shared_w2_tensors, + local_shared_w3_w1_bias_tensors if module.bias else None, + local_shared_w2_bias_tensors if module.bias else None) + weight_fns = { + 'w3_w1_weight': local_shared_w3_w1_tensors, + 'w2_weight': local_shared_w2_tensors + } + if module.bias: + weight_fns.update({ + 'w3_w1_bias': local_shared_w3_w1_bias_tensors, + 'w2_bias': local_shared_w2_bias_tensors + }) + module.register_all_parameter_slot_and_to_fix_weight_fns(weight_fns) module.layer_load_balancer.host_tensor_sharer.finalize_layer_weights( ) @@ -179,13 +337,13 @@ class FusedMoEMethodBase(ABC): Due to the special handling of slot_start and slot_end, we require the subclasses to implement this method or explicitly raise NotImplementedError. """ - raise NotImplementedError + # TODO: remove this method, it's no longer needed def apply(self, module: torch.nn.Module, input: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ Apply the quantization method to the input tensor. - This isn’t necessary for all quantization methods, but it’s useful for + This isn't necessary for all quantization methods, but it's useful for certain backends that can encapsulate the MoE forward function. """ raise NotImplementedError @@ -199,12 +357,18 @@ class FusedMoEMethodBase(ABC): Load w1 and w3 weights for each expert. Override this method if you need to preprocess the weights differently. """ - w1_weight_shard = load_weight_shard(w1_weight, module.tp_size, + # device don't have to be 'cuda', e.g. 'cpu' for online EPLB + device = dst_w3_w1_weight.device + w1_weight_shard = load_weight_shard(w1_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) - w3_weight_shard = load_weight_shard(w3_weight, module.tp_size, + TensorParallelMode.COLUMN, + device=device) + w3_weight_shard = load_weight_shard(w3_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) + TensorParallelMode.COLUMN, + device=device) w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype), @@ -218,9 +382,13 @@ class FusedMoEMethodBase(ABC): Load w2 weight for each expert. Override this method if you need to preprocess the weights differently. """ - w2_weight_shard = load_weight_shard(w2_weight, module.tp_size, + # device don't have to be 'cuda', e.g. 'cpu' for online EPLB + device = dst_w2_weight.device + w2_weight_shard = load_weight_shard(w2_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.ROW) + TensorParallelMode.ROW, + device=device) dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), non_blocking=True) @@ -249,6 +417,79 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase): return tuple() +def load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale, + dst_fc31_input_scale: torch.Tensor): + dst_fc31_input_scale.copy_( + max(w1_input_scale[...].reshape([]), w3_input_scale[...].reshape([]))) + + +def load_expert_fc2_input_scale_fp8_qdq(w2_input_scale, + dst_fc2_input_scale: torch.Tensor): + dst_fc2_input_scale.copy_(w2_input_scale[...].reshape([])) + + +def load_activation_scales_fp8_qdq(module: torch.nn.Module, weights: Dict): + tmp_fc31_input_scale = torch.empty(module.num_experts, dtype=torch.float32) + tmp_fc2_input_scale = torch.empty(module.num_experts, dtype=torch.float32) + for expert_id in range(module.num_experts): + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_input_scale = weights[f"{expert_id}.w1.input_scale"] + w3_input_scale = weights[f"{expert_id}.w3.input_scale"] + w2_input_scale = weights[f"{expert_id}.w2.input_scale"] + elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_input_scale = weights[f"gate_up_proj_input_scale"] + w3_input_scale = weights[f"gate_up_proj_input_scale"] + w2_input_scale = weights[f"down_proj_input_scale"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" + ) + + load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale, + tmp_fc31_input_scale[expert_id]) + + load_expert_fc2_input_scale_fp8_qdq(w2_input_scale, + tmp_fc2_input_scale[expert_id]) + + # max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales. + # It's used to quantize fc31 input inside the MOE op + max_fc31_input_scale = tmp_fc31_input_scale.max() + # max_fc2_input_scale is the maximum of all w2 input scales. + max_fc2_input_scale = tmp_fc2_input_scale.max() + + return max_fc31_input_scale, max_fc2_input_scale + + +def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module, + w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight: torch.Tensor): + w1_weight_scale = w1_weight_scale[...].reshape([]) + w3_weight_scale = w3_weight_scale[...].reshape([]) + max_w3_w1_weight_scale = max(w1_weight_scale, w3_weight_scale) + + w3_weight = dst_w3_w1_weight.narrow( + dim=0, start=0, + length=module.intermediate_size_per_partition).to(dtype=module.dtype) + w1_weight = dst_w3_w1_weight.narrow( + dim=0, + start=module.intermediate_size_per_partition, + length=module.intermediate_size_per_partition).to(dtype=module.dtype) + dequant_w3_weight = w3_weight * w3_weight_scale + dequant_w1_weight = w1_weight * w1_weight_scale + requant_w3_weight = (dequant_w3_weight / max_w3_w1_weight_scale).to( + torch.float8_e4m3fn) + requant_w1_weight = (dequant_w1_weight / max_w3_w1_weight_scale).to( + torch.float8_e4m3fn) + + dst_w3_w1_weight.narrow( + dim=0, start=0, + length=module.intermediate_size_per_partition).copy_(requant_w3_weight) + dst_w3_w1_weight.narrow( + dim=0, + start=module.intermediate_size_per_partition, + length=module.intermediate_size_per_partition).copy_(requant_w1_weight) + + class FP8QDQFusedMoEMethod(FusedMoEMethodBase): def create_weights(self, module: torch.nn.Module): @@ -302,17 +543,6 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase): fc1_input_dequant=module.fc31_input_dequant, ) - def load_expert_fc31_input_scale_fp8_qdq( - self, w1_input_scale, w3_input_scale, - dst_fc31_input_scale: torch.Tensor): - dst_fc31_input_scale.copy_( - max(w1_input_scale[...].reshape([]), - w3_input_scale[...].reshape([]))) - - def load_expert_fc2_input_scale_fp8_qdq(self, w2_input_scale, - dst_fc2_input_scale: torch.Tensor): - dst_fc2_input_scale.copy_(w2_input_scale[...].reshape([])) - def load_expert_w3_w1_weight_scale_fp8_qdq( self, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale: torch.Tensor): @@ -320,73 +550,14 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase): w3_weight_scale = w3_weight_scale[...].reshape([]) dst_w3_w1_weight_scale.copy_(max(w1_weight_scale, w3_weight_scale)) - def requantize_expert_w3_w1_weight_fp8_qdq(self, module: torch.nn.Module, - w1_weight_scale, w3_weight_scale, - dst_w3_w1_weight: torch.Tensor): - w1_weight_scale = w1_weight_scale[...].reshape([]) - w3_weight_scale = w3_weight_scale[...].reshape([]) - max_w3_w1_weight_scale = max(w1_weight_scale, w3_weight_scale) - - w3_weight = dst_w3_w1_weight.narrow( - dim=0, start=0, length=module.intermediate_size_per_partition).to( - dtype=module.dtype) - w1_weight = dst_w3_w1_weight.narrow( - dim=0, - start=module.intermediate_size_per_partition, - length=module.intermediate_size_per_partition).to( - dtype=module.dtype) - dequant_w3_weight = w3_weight * w3_weight_scale - dequant_w1_weight = w1_weight * w1_weight_scale - requant_w3_weight = (dequant_w3_weight / max_w3_w1_weight_scale).to( - torch.float8_e4m3fn) - requant_w1_weight = (dequant_w1_weight / max_w3_w1_weight_scale).to( - torch.float8_e4m3fn) - - dst_w3_w1_weight.narrow( - dim=0, start=0, - length=module.intermediate_size_per_partition).copy_( - requant_w3_weight) - dst_w3_w1_weight.narrow( - dim=0, - start=module.intermediate_size_per_partition, - length=module.intermediate_size_per_partition).copy_( - requant_w1_weight) - def load_expert_w2_weight_scale_fp8(self, w2_weight_scale, dst_w2_weight_scale: torch.Tensor): dst_w2_weight_scale.copy_(w2_weight_scale[...].reshape([])) def load_quant_scales(self, module: torch.nn.Module, weights: Dict): # Step1: Load input scales. - tmp_fc31_input_scale = torch.empty(module.num_experts, - dtype=torch.float32) - tmp_fc2_input_scale = torch.empty(module.num_experts, - dtype=torch.float32) - for expert_id in range(module.num_experts): - if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - w1_input_scale = weights[f"{expert_id}.w1.input_scale"] - w3_input_scale = weights[f"{expert_id}.w3.input_scale"] - w2_input_scale = weights[f"{expert_id}.w2.input_scale"] - elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: - w1_input_scale = weights[f"gate_up_proj_input_scale"] - w3_input_scale = weights[f"gate_up_proj_input_scale"] - w2_input_scale = weights[f"down_proj_input_scale"] - else: - raise NotImplementedError( - f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" - ) - - self.load_expert_fc31_input_scale_fp8_qdq( - w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id]) - - self.load_expert_fc2_input_scale_fp8_qdq( - w2_input_scale, tmp_fc2_input_scale[expert_id]) - - # max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales. - # It's used to quantize fc31 input inside the MOE op - max_fc31_input_scale = tmp_fc31_input_scale.max() - # max_fc2_input_scale is the maximum of all w2 input scales. - max_fc2_input_scale = tmp_fc2_input_scale.max() + max_fc31_input_scale, max_fc2_input_scale = load_activation_scales_fp8_qdq( + module, weights) # Step2: Load weight scales and requantize w3_w1_weight. tmp_w3_w1_weight_scale = torch.empty(module.expert_size_per_partition, @@ -415,7 +586,7 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase): w1_weight_scale, w3_weight_scale, tmp_w3_w1_weight_scale[expert_idx]) - self.requantize_expert_w3_w1_weight_fp8_qdq( + requantize_expert_w3_w1_weight_fp8_qdq( module, w1_weight_scale, w3_weight_scale, module.w3_w1_weight.data[expert_idx]) @@ -468,45 +639,8 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase): def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode): - - if get_sm_version() == 100: - expert_ids = set(module.initial_local_expert_ids) - if self.need_load_shared_weights(module): - expert_ids.update( - module.layer_load_balancer.get_load_expert_ids()) - for name in list(weights.keys()): - if name.endswith("weight_scale_inv"): - if int(name.split(".")[0]) not in expert_ids: - continue - weight_name = name.replace("weight_scale_inv", "weight") - logger.debug(f"Resmoothing {weight_name}") - weight = weights[weight_name][:] - scale = weights[name][:] - weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( - weight, scale) super().load_weights(module, weights, weight_loading_mode) - if get_sm_version() == 100: - transfromed_w3_w1_scale = transform_sf_into_required_layout( - module.quant_scales[0], - mn=module.w3_w1_weight.shape[1], - k=module.w3_w1_weight.shape[2], - recipe=(1, 128, 128), - num_groups=module.w3_w1_weight.shape[0], - is_sfa=False) - module.w3_w1_weight_scaling_factor = nn.Parameter( - transfromed_w3_w1_scale, requires_grad=False) - transfromed_w2_scale = transform_sf_into_required_layout( - module.quant_scales[1], - mn=module.w2_weight.shape[1], - k=module.w2_weight.shape[2], - recipe=(1, 128, 128), - num_groups=module.w3_w1_weight.shape[0], - is_sfa=False) - module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, - requires_grad=False) - self.setup_quant_scales(module) - def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales( fc_weight_scales=module.w3_w1_weight_scaling_factor, @@ -527,6 +661,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase): self, module: torch.nn.Module, weights: Dict, load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor, dst_w2_weight_scale: torch.Tensor, device): + assert device.type == "cuda" for local_slot_id, expert_id in enumerate(load_expert_ids): if module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: w3_scale = weights['gate_up_proj_weight_scale'][ @@ -603,6 +738,190 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase): }) +class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm( + DeepSeekFP8BlockScalesFusedMoEMethod): + + def load_weights(self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode): + if get_sm_version() == 100: + expert_ids = set(module.initial_local_expert_ids) + if self.need_load_shared_weights(module): + expert_ids.update( + module.layer_load_balancer.get_load_expert_ids()) + for name in list(weights.keys()): + if name.endswith("weight_scale_inv"): + if int(name.split(".")[0]) not in expert_ids: + continue + weight_name = name.replace("weight_scale_inv", "weight") + trtllm_logger.logger.debug(f"Resmoothing {weight_name}") + weight = weights[weight_name][:] + scale = weights[name][:] + weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( + weight, scale) + super().load_weights(module, weights, weight_loading_mode) + + if get_sm_version() == 100: + transfromed_w3_w1_scale = transform_sf_into_required_layout( + module.quant_scales[0], + mn=module.w3_w1_weight.shape[1], + k=module.w3_w1_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w3_w1_weight_scaling_factor = nn.Parameter( + transfromed_w3_w1_scale, requires_grad=False) + transfromed_w2_scale = transform_sf_into_required_layout( + module.quant_scales[1], + mn=module.w2_weight.shape[1], + k=module.w2_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, + requires_grad=False) + self.setup_quant_scales(module) + + +class INT8WoqPerChannelFusedMoEMethod(FusedMoEMethodBase): + + def create_weights(self, module: torch.nn.Module): + module.sm_version = get_sm_version() + module.sm_version = 80 if module.sm_version >= 90 else module.sm_version + module.preprocessor = preprocess_weights_for_mixed_gemm + + weight_dtype = torch.int8 + if not module.quant_config.layer_quant_mode.is_int8_weight_only(): + raise NotImplementedError( + f"Weight Only Quantization currently only supports INT8. Got: {module.quant_config.layer_quant_mode}." + ) + + # notice the weight shape for int8 weight-only is different from the original shape, + # since the quantized weights have their own layout + w3_w1_weight_shape = (module.expert_size_per_partition, + module.hidden_size, + module.intermediate_size_per_partition * 2) + w2_weight_shape = (module.expert_size_per_partition, + module.intermediate_size_per_partition, + module.hidden_size) + + fc31_weight_scale = nn.Parameter(torch.empty( + module.expert_size_per_partition, + module.intermediate_size_per_partition * 2, + dtype=module.dtype), + requires_grad=False) + module.register_parameter("fc31_weight_scale", fc31_weight_scale) + + fc2_weight_scale = nn.Parameter(torch.empty( + module.expert_size_per_partition, + module.hidden_size, + dtype=module.dtype), + requires_grad=False) + module.register_parameter("fc2_weight_scale", fc2_weight_scale) + + super().create_weights(module, weight_dtype, w3_w1_weight_shape, + w2_weight_shape) + self.setup_quant_scales(module) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = FusedMoEQuantScalesINT8WoqPerChannel( + fc31_weight_scale=module.fc31_weight_scale, + fc2_weight_scale=module.fc2_weight_scale, + ) + + def get_quant_scales(self, module: torch.nn.Module, slot_start, + slot_end) -> tuple[torch.Tensor, ...]: + assert module.smart_router + return FusedMoEQuantScalesINT8WoqPerChannel( + fc31_weight_scale=module.fc31_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc2_weight_scale=module.fc2_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + ) + + def load_expert_w3_w1_weight(self, module: torch.nn.Module, + w1_weight: torch.Tensor, + w3_weight: torch.Tensor, + dst_w3_w1_weight: torch.Tensor): + """ + Load w1 and w3 weights for each expert. + """ + w1_weight_shard = load_weight_shard(w1_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN) + w3_weight_shard = load_weight_shard(w3_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN) + w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) + + weight_dtype = torch.int8 + + assert module.dtype in [torch.float16, torch.bfloat16], \ + f"activation dtype should be float16 or bfloat16, got {module.dtype}" + if not module.quant_config.layer_quant_mode.is_int8_weight_only(): + raise NotImplementedError( + f"weight dtype should be INT8. Got: {module.quant_config.layer_quant_mode}." + ) + # preprocess the weights for mixed gemm + w31_weight_shard = module.preprocessor(w31_weight_shard.T.contiguous(), + weight_dtype, module.dtype, + module.sm_version).contiguous() + dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype), + non_blocking=True) + + def load_expert_w2_weight(self, module: torch.nn.Module, + w2_weight: torch.Tensor, + dst_w2_weight: torch.Tensor): + """ + Load w2 weight for each expert. + """ + w2_weight_shard = load_weight_shard(w2_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.ROW) + + weight_dtype = torch.int8 + if not module.quant_config.layer_quant_mode.is_int8_weight_only(): + raise NotImplementedError( + f"Weight Only Quantization currently only supports INT8. Got: {module.quant_config.layer_quant_mode}." + ) + + # preprocess the weights for mixed gemm + w2_weight_shard = module.preprocessor(w2_weight_shard.T.contiguous(), + weight_dtype, module.dtype, + module.sm_version).contiguous() + dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), + non_blocking=True) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # fc31 scales + all_w3_scales = [ + load_weight_shard(weights[f"{expert_id}.w3.weight_scale"], + module.tp_size, module.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in module.initial_local_expert_ids + ] + all_w1_scales = [ + load_weight_shard(weights[f"{expert_id}.w1.weight_scale"], + module.tp_size, module.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in module.initial_local_expert_ids + ] + w3_w1_scales = torch.cat( + [torch.stack(all_w3_scales), + torch.stack(all_w1_scales)], dim=-1) + w3_w1_scales = w3_w1_scales.to(module.dtype) + module.fc31_weight_scale.data.copy_(w3_w1_scales.contiguous()) + + # fc2 scales + all_w2_scales = [ + load_weight_shard(weights[f"{expert_id}.w2.weight_scale"], + module.tp_size, module.tp_rank, + TensorParallelMode.ROW) + for expert_id in module.initial_local_expert_ids + ] + w2_scales = torch.stack(all_w2_scales).to(module.dtype) + module.fc2_weight_scale.data.copy_(w2_scales.contiguous()) + + class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): def create_weights(self, module: torch.nn.Module): @@ -634,6 +953,7 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): w2_weight_shape = (module.expert_size_per_partition, module.hidden_size, module.intermediate_size_per_partition // 2) + # Multiply act with reciprocal of per-channel pre_quant_scale * per-tensor input_scale fc31_act_scale = nn.Parameter(torch.empty(1, module.hidden_size, dtype=module.dtype), @@ -664,6 +984,7 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): requires_grad=False) module.register_parameter("fc2_weight_scale", fc2_weight_scale) + # Multiply W@X with per-tensor weight_scale_2 * per-tensor input_scale. fc31_alpha = nn.Parameter(torch.empty(module.expert_size_per_partition, 1, dtype=torch.float32), @@ -720,20 +1041,27 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): Load w1 and w3 weights for each expert. Override this method if you need to preprocess the weights differently. """ - w1_weight_shard = load_weight_shard(w1_weight, module.tp_size, + device = dst_w3_w1_weight.device + self.device = device + assert device.type == "cuda" + + w1_weight_shard = load_weight_shard(w1_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) - w3_weight_shard = load_weight_shard(w3_weight, module.tp_size, + TensorParallelMode.COLUMN, + device=device) + w3_weight_shard = load_weight_shard(w3_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) + TensorParallelMode.COLUMN, + device=device) w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) + packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 + unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 + # SM89 if module.sm_version == 89: - import tensorrt_llm.quantization.functional as trtllm_f - - preprocessor = trtllm_f.preprocess_weights_for_mixed_gemm - packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 - unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 + preprocessor = preprocess_weights_for_mixed_gemm w31_weight_shard = packer( unpacker(w31_weight_shard.cpu()).T.contiguous()).to( @@ -741,6 +1069,24 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): w31_weight_shard = preprocessor(w31_weight_shard, torch.quint4x2, torch.float8_e4m3fn, 89).view(dst_w3_w1_weight.shape) + # SM90 ModelOpt quantized weights + elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + # Original: [(N//2)*I4x2, K] which is two int4 elts in output dim packed into one + # Transpose: [K, (N//2)*I4x2] + transposed = w31_weight_shard.cpu().T.contiguous() + # Unpack: [K, N*I8] + unpacked = unpacker(transposed.view(torch.int8)) + # Transpose: [N, K*I8] + transposed = unpacked.T.contiguous() + # Pack: [N, (K//2)*I4x2] + w31_weight_shard = packer(transposed) + elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM: + pass + else: + raise NotImplementedError( + f"Unsupported configuration: SM{module.sm_version} and {module.weight_loading_mode}." + ) + dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype), non_blocking=True) @@ -751,16 +1097,18 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): Load w2 weight for each expert. Override this method if you need to preprocess the weights differently. """ - w2_weight_shard = load_weight_shard(w2_weight, module.tp_size, + device = dst_w2_weight.device + assert device.type == "cuda" + w2_weight_shard = load_weight_shard(w2_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.ROW) + TensorParallelMode.ROW, + device=device) + packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 + unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 if module.sm_version == 89: - import tensorrt_llm.quantization.functional as trtllm_f - - preprocessor = trtllm_f.preprocess_weights_for_mixed_gemm - packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 - unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 + preprocessor = preprocess_weights_for_mixed_gemm w2_weight_shard = packer( unpacker(w2_weight_shard.cpu()).T.contiguous()).to( @@ -768,40 +1116,120 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): w2_weight_shard = preprocessor(w2_weight_shard, torch.quint4x2, torch.float8_e4m3fn, 89).view(dst_w2_weight.shape) - + # SM90 ModelOpt quantized weights + elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + # Original: [(N//2)*I4x2, K] which is two int4 elts in output dim packed into one + # Transpose: [K, (N//2)*I4x2] + transposed = w2_weight_shard.cpu().T.contiguous() + # Unpack: [K, N*I8] + unpacked = unpacker(transposed.view(torch.int8)) + # Transpose: [N, K*I8] + transposed = unpacked.T.contiguous() + # Pack: [N, (K//2)*I4x2] + w2_weight_shard = packer(transposed) + elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM: + pass + else: + raise NotImplementedError( + f"Unsupported configuration: SM{module.sm_version} and {module.weight_loading_mode}." + ) dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), non_blocking=True) def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + assert self.device.type == "cuda" + # fc31 scales + w4a8_custom = module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM + if w4a8_custom: + weight_scale_name = "weight_scale_inv" + else: + weight_scale_name = "weight_scale" + assert (len(module.interleave) == 2) + # fc31 scales all_w3_input_scales = [ - load_weight_shard(weights[f"{expert_id}.w3.input_scale"]) + load_weight_shard(weights[f"{expert_id}.w3.input_scale"], + device=self.device) for expert_id in module.initial_local_expert_ids ] all_w1_input_scales = [ - load_weight_shard(weights[f"{expert_id}.w1.input_scale"]) + load_weight_shard(weights[f"{expert_id}.w1.input_scale"], + device=self.device) for expert_id in module.initial_local_expert_ids ] all_w3_w1_input_scales_max = torch.max( torch.stack(all_w3_input_scales), torch.stack(all_w1_input_scales)).max() - module.fc31_act_scale.data.copy_( - torch.ones_like(module.fc31_act_scale) * - (1 / all_w3_w1_input_scales_max)) - module.fc31_alpha.data.copy_((torch.ones_like(module.fc31_alpha) * - all_w3_w1_input_scales_max).float()) + if w4a8_custom: + # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale + module.fc31_act_scale.data.copy_( + torch.ones_like(module.fc31_act_scale, device=self.device) * + (1 / all_w3_w1_input_scales_max)) + module.fc31_alpha.data.copy_( + (torch.ones_like(module.fc31_alpha, device=self.device) * + all_w3_w1_input_scales_max).float()) + else: + # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored + all_w3_pre_quant_scales = [ + load_weight_shard(weights[f"{expert_id}.w3.pre_quant_scale"], + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=self.device) + for expert_id in module.initial_local_expert_ids + ] + all_w1_pre_quant_scales = [ + load_weight_shard(weights[f"{expert_id}.w1.pre_quant_scale"], + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=self.device) + for expert_id in module.initial_local_expert_ids + ] + all_w3_w1_pre_quant_scales_max = torch.max( + torch.stack(all_w3_pre_quant_scales + + all_w1_pre_quant_scales).to(module.dtype), + dim=0, + ).values + module.fc31_act_scale.data.copy_( + torch.ones_like(module.fc31_act_scale, device=self.device) * + (all_w3_w1_pre_quant_scales_max) * + (1 / all_w3_w1_input_scales_max)) + # In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored + all_w3_weight_scale_2 = [ + load_weight_shard(weights[f"{expert_id}.w3.weight_scale_2"], + device=self.device) + for expert_id in module.initial_local_expert_ids + ] + all_w1_weight_scale_2 = [ + load_weight_shard(weights[f"{expert_id}.w1.weight_scale_2"], + device=self.device) + for expert_id in module.initial_local_expert_ids + ] + all_w3_w1_weight_scale_2_max = torch.max( + torch.stack(all_w3_weight_scale_2 + all_w1_weight_scale_2).to( + module.dtype), + dim=0, + ).values + module.fc31_alpha.data.copy_(all_w3_w1_weight_scale_2_max.float() * + all_w3_w1_input_scales_max.float()) + # Per-group weight_scale all_w3_scales = [ - load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"], - module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) + load_weight_shard(weights[f"{expert_id}.w3.{weight_scale_name}"], + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=self.device) for expert_id in module.initial_local_expert_ids ] all_w1_scales = [ - load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"], - module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) + load_weight_shard(weights[f"{expert_id}.w1.{weight_scale_name}"], + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=self.device) for expert_id in module.initial_local_expert_ids ] all_w3_w1_scales = torch.cat( @@ -812,6 +1240,8 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): else: w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view( module.dtype) + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w3_w1_scales /= all_w3_w1_weight_scale_2_max.float() w3_w1_s_shape = w3_w1_scales.shape w3_w1_scales_interleaved = w3_w1_scales.reshape( w3_w1_s_shape[0], w3_w1_s_shape[1], @@ -825,21 +1255,57 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): # fc2 scales all_w2_input_scales = [ - load_weight_shard(weights[f"{expert_id}.w2.input_scale"]) + load_weight_shard(weights[f"{expert_id}.w2.input_scale"], + device=self.device) for expert_id in module.initial_local_expert_ids ] all_w2_input_scales_max = torch.stack(all_w2_input_scales).to( module.dtype).max() - module.fc2_act_scale.data.copy_( - torch.ones_like(module.fc2_act_scale) * - (1 / all_w2_input_scales_max)) - module.fc2_alpha.data.copy_((torch.ones_like(module.fc2_alpha) * - all_w2_input_scales_max).float()) + if w4a8_custom: + # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale + module.fc2_act_scale.data.copy_( + torch.ones_like(module.fc2_act_scale, device=self.device) * + (1 / all_w2_input_scales_max)) + # In custom W4A8 ckpt, per-tensor weight_scale_2 is fused into alpha + module.fc2_alpha.data.copy_( + (torch.ones_like(module.fc2_alpha, device=self.device) * + all_w2_input_scales_max).float()) + else: + # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored + all_w2_pre_quant_scales = [ + load_weight_shard(weights[f"{expert_id}.w2.pre_quant_scale"], + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=self.device) + for expert_id in module.initial_local_expert_ids + ] + all_w2_pre_quant_scales_max = torch.max( + torch.stack(all_w2_pre_quant_scales).to(module.dtype), + dim=0).values + module.fc2_act_scale.data.copy_( + torch.ones_like(module.fc2_act_scale, device=self.device) * + (all_w2_pre_quant_scales_max.unsqueeze(-1)) * + (1 / all_w2_input_scales_max)) + # In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored + all_w2_weight_scale_2 = [ + load_weight_shard(weights[f"{expert_id}.w2.weight_scale_2"], + device=self.device) + for expert_id in module.initial_local_expert_ids + ] + all_w2_weight_scale_2_max = torch.stack(all_w2_weight_scale_2).to( + module.dtype).max() + module.fc2_alpha.data.copy_(all_w2_weight_scale_2_max.float() * + all_w2_input_scales_max.float()) + + # Per-group weight_scale all_w2_scales = [ - load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"], - module.tp_size, module.tp_rank, - TensorParallelMode.ROW) + load_weight_shard(weights[f"{expert_id}.w2.{weight_scale_name}"], + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=self.device) for expert_id in module.initial_local_expert_ids ] if module.sm_version == 89: @@ -848,6 +1314,224 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): else: w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view( module.dtype) + + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w2_scales /= all_w2_weight_scale_2_max.float() + w2_s_shape = w2_scales.shape + w2_scales_interleaved = w2_scales.reshape( + w2_s_shape[0], w2_s_shape[1], + (w2_s_shape[2] // module.interleave[1]), module.interleave[1]) + w2_scales_interleaved = w2_scales_interleaved.permute(0, 2, 1, 3) + w2_scales_interleaved = w2_scales_interleaved.reshape( + w2_s_shape[0], w2_s_shape[2] // module.interleave[1], + w2_s_shape[1] * module.interleave[1]) + module.fc2_weight_scale.data.copy_(w2_scales_interleaved.contiguous()) + + +class WFP4A16FusedMoEMethod(FusedMoEMethodBase): + + group_size = 32 + + def create_weights(self, module: torch.nn.Module): + module.sm_version = get_sm_version() + if module.sm_version == 90: + module.interleave = [] + for k_shape in [ + module.hidden_size, module.intermediate_size_per_partition + ]: + module.interleave.append(128 // self.group_size) + else: + raise NotImplementedError( + f"WFP4A16 MoE is unsupported on SM{module.sm_version}.") + weight_dtype = torch.uint8 + w3_w1_weight_shape = (module.expert_size_per_partition, + module.intermediate_size_per_partition * 2, + module.hidden_size // 2) + w2_weight_shape = (module.expert_size_per_partition, module.hidden_size, + module.intermediate_size_per_partition // 2) + + # col parallel + assert module.hidden_size % (self.group_size * + module.interleave[0]) == 0 + scale_dtype = torch.uint8 + fc31_weight_scale = nn.Parameter(torch.empty( + module.expert_size_per_partition, + module.hidden_size // (self.group_size * module.interleave[0]), + module.intermediate_size_per_partition * 2 * module.interleave[0], + dtype=scale_dtype), + requires_grad=False) + module.register_parameter("fc31_weight_scale", fc31_weight_scale) + + # row parallel + assert module.intermediate_size_per_partition % ( + self.group_size * module.interleave[1]) == 0 + fc2_weight_scale = nn.Parameter( + torch.empty(module.expert_size_per_partition, + module.intermediate_size_per_partition // + (self.group_size * module.interleave[1]), + module.hidden_size * module.interleave[1], + dtype=scale_dtype), + requires_grad=False) + module.register_parameter("fc2_weight_scale", fc2_weight_scale) + + super().create_weights(module, weight_dtype, w3_w1_weight_shape, + w2_weight_shape) + self.setup_quant_scales(module) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = FusedMoEQuantScalesW4A16MXFP4( + scale_1_interleaved=module.fc31_weight_scale, + scale_2_interleaved=module.fc2_weight_scale) + + def get_quant_scales(self, module: torch.nn.Module, slot_start, + slot_end) -> tuple[torch.Tensor, ...]: + assert module.smart_router + return FusedMoEQuantScalesW4A16MXFP4( + scale_1_interleaved=module.fc31_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + scale_2_interleaved=module.fc2_weight_scale.narrow( + 0, slot_start, slot_end - slot_start)) + + def load_expert_w3_w1_weight(self, module: torch.nn.Module, + w1_weight: torch.Tensor, + w3_weight: torch.Tensor, + dst_w3_w1_weight: torch.Tensor): + """ + Load w1 and w3 weights for each expert. + Override this method if you need to preprocess the weights differently. + """ + device = dst_w3_w1_weight.device + assert device.type == "cuda" + w1_weight_shard = load_weight_shard(w1_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w3_weight_shard = load_weight_shard(w3_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + + pad_size_inter = module.intermediate_size_per_partition - w3_weight_shard.shape[ + 0] + if w3_weight_shard.ndim == 2: + pad_size_hidden = module.hidden_size // 2 - w3_weight_shard.shape[1] + pad_shape = (0, pad_size_hidden, 0, pad_size_inter) + elif w3_weight_shard.ndim == 1: + pad_shape = (0, pad_size_inter) + else: + raise NotImplementedError( + f"Invalid shape of w1_weight_shard {w1_weight_shard.shape} and w3_weight_shard {w1_weight_shard.shape}" + ) + + w1_weight_shard = torch.nn.functional.pad(w1_weight_shard, pad_shape) + w3_weight_shard = torch.nn.functional.pad(w3_weight_shard, pad_shape) + + w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) + + dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype), + non_blocking=True) + + def load_expert_w2_weight(self, module: torch.nn.Module, + w2_weight: torch.Tensor, + dst_w2_weight: torch.Tensor): + """ + Load w2 weight for each expert. + Override this method if you need to preprocess the weights differently. + """ + device = dst_w2_weight.device + assert device.type == "cuda" + w2_weight_shard = load_weight_shard(w2_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + + pad_size_hidden = module.hidden_size - w2_weight_shard.shape[0] + if w2_weight_shard.ndim == 2: + pad_size_inter = module.intermediate_size_per_partition // 2 - w2_weight_shard.shape[ + 1] + pad_shape = (0, pad_size_inter, 0, pad_size_hidden) + elif w2_weight_shard.ndim == 1: + pad_shape = (0, pad_size_hidden) + else: + raise NotImplementedError( + f"Invalid shape of w2_weight_shard {w2_weight_shard.shape}") + + w2_weight_shard = torch.nn.functional.pad(w2_weight_shard, pad_shape) + dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), + non_blocking=True) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + device = module.fc31_weight_scale.data.device + assert device.type == "cuda" + + # fc31 scales + assert (len(module.interleave) == 2) + + all_w3_scales = [] + all_w1_scales = [] + for expert_id in module.initial_local_expert_ids: + w3_scale_shard = load_weight_shard( + weights[f"{expert_id}.w3.weight_scale_inv"], + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w1_scale_shard = load_weight_shard( + weights[f"{expert_id}.w1.weight_scale_inv"], + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + + pad_size_hidden = module.hidden_size // self.group_size - w3_scale_shard.shape[ + 1] + pad_size_inter = module.intermediate_size_per_partition - w3_scale_shard.shape[ + 0] + w3_scale_shard = torch.nn.functional.pad( + w3_scale_shard, (0, pad_size_hidden, 0, pad_size_inter)) + w1_scale_shard = torch.nn.functional.pad( + w1_scale_shard, (0, pad_size_hidden, 0, pad_size_inter)) + + all_w3_scales.append(w3_scale_shard) + all_w1_scales.append(w1_scale_shard) + + all_w3_w1_scales = torch.cat( + [torch.stack(all_w3_scales), + torch.stack(all_w1_scales)], dim=-2) + + w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view(module.dtype) + w3_w1_s_shape = w3_w1_scales.shape + w3_w1_scales_interleaved = w3_w1_scales.reshape( + w3_w1_s_shape[0], w3_w1_s_shape[1], + (w3_w1_s_shape[2] // module.interleave[0]), module.interleave[0]) + w3_w1_scales_interleaved = w3_w1_scales_interleaved.permute(0, 2, 1, 3) + w3_w1_scales_interleaved = w3_w1_scales_interleaved.reshape( + w3_w1_s_shape[0], w3_w1_s_shape[2] // module.interleave[0], + w3_w1_s_shape[1] * module.interleave[0]) + module.fc31_weight_scale.data.copy_( + w3_w1_scales_interleaved.contiguous()) + + # fc2 scales + all_w2_scales = [] + for expert_id in module.initial_local_expert_ids: + w2_scales_shard = load_weight_shard( + weights[f"{expert_id}.w2.weight_scale_inv"], + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + pad_size_hidden = module.hidden_size - w2_scales_shard.shape[0] + pad_size_inter = module.intermediate_size_per_partition // self.group_size - w2_scales_shard.shape[ + 1] + w2_scales_shard = torch.nn.functional.pad( + w2_scales_shard, (0, pad_size_inter, 0, pad_size_hidden)) + all_w2_scales.append(w2_scales_shard) + + w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view( + module.dtype) w2_s_shape = w2_scales.shape w2_scales_interleaved = w2_scales.reshape( w2_s_shape[0], w2_s_shape[1], @@ -995,6 +1679,14 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase): expert_idx = local_slot_id + if not torch.allclose(w1_weight_scale_2, w3_weight_scale_2): + logger.warning( + f"w1_weight_scale_2 != w3_weight_scale_2 ({w1_weight_scale_2} != {w3_weight_scale_2}), selecting the larger value. Accuracy may be affected." + ) + w1_weight_scale_2 = torch.max(w1_weight_scale_2, + w3_weight_scale_2) + w3_weight_scale_2 = w1_weight_scale_2 + self.load_expert_w3_w1_weight_scale_nvfp4( module, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale[expert_idx]) @@ -1048,7 +1740,7 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase): module.w3_w1_weight_scale.data, module.w2_weight_scale.data, module.fc31_alpha.data, module.fc2_alpha.data) - # Step 3: if need load into shared + # Step 3: if needed, load into shared if self.need_load_shared_weights(module): local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids( ) @@ -1130,12 +1822,18 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod): self, module: torch.nn.Module, w1_weight_scale: torch.Tensor, w3_weight_scale: torch.Tensor, dst_w3_w1_weight_scale: torch.Tensor): - w1_weight_scale = load_weight_shard(w1_weight_scale, module.tp_size, + # device don't have to be 'cuda', e.g. 'cpu' for online EPLB + device = dst_w3_w1_weight_scale.device + w1_weight_scale = load_weight_shard(w1_weight_scale, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) - w3_weight_scale = load_weight_shard(w3_weight_scale, module.tp_size, + TensorParallelMode.COLUMN, + device=device) + w3_weight_scale = load_weight_shard(w3_weight_scale, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) + TensorParallelMode.COLUMN, + device=device) # Keep weights in device buffer # w3 dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow( @@ -1153,7 +1851,7 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod): orig_shape = dst_w3_w1_weight_scale.shape - dst_w3_w1_weight_scale_interleaved = torch.ops.trtllm.nvfp4_block_scale_interleave( + dst_w3_w1_weight_scale_interleaved = torch.ops.trtllm.block_scale_interleave( dst_w3_w1_weight_scale.view(float4_sf_dtype)).view( self.block_scales_dtype).reshape(orig_shape) @@ -1164,16 +1862,20 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod): def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module, w2_weight_scale: torch.Tensor, dst_w2_weight_scale: torch.Tensor): - w2_weight_scale = load_weight_shard(w2_weight_scale, module.tp_size, + # device don't have to be 'cuda', e.g. 'cpu' for online EPLB + device = dst_w2_weight_scale.device + w2_weight_scale = load_weight_shard(w2_weight_scale, + module.tp_size, module.tp_rank, - TensorParallelMode.ROW) + TensorParallelMode.ROW, + device=device) # Keep weights in device buffer dst_w2_weight_scale.copy_( w2_weight_scale.view(dst_w2_weight_scale.dtype)) orig_shape = dst_w2_weight_scale.shape - dst_w2_weight_scale_interleaved = torch.ops.trtllm.nvfp4_block_scale_interleave( + dst_w2_weight_scale_interleaved = torch.ops.trtllm.block_scale_interleave( dst_w2_weight_scale.view(float4_sf_dtype)).view( self.block_scales_dtype).reshape(orig_shape) @@ -1190,49 +1892,6 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): # This assumes the same input shape always results in the same permute indices _cache_permute_indices: Dict[torch.Size, torch.Tensor] = {} - def _maybe_get_cached_w3_w1_permute_indices( - self, - dst_w3_w1_weight: torch.Tensor, - epilogue_tile_m: int, - num_elts_per_sf: Union[None, int] = None) -> torch.Tensor: - if dst_w3_w1_weight.shape not in self._cache_permute_indices: - # Get permute indices and chain them together - permute0 = get_reorder_rows_for_gated_act_gemm_row_indices( - dst_w3_w1_weight) - if num_elts_per_sf is None: - permute1 = get_shuffle_matrix_a_row_indices( - dst_w3_w1_weight, epilogue_tile_m=epilogue_tile_m) - else: - permute1 = get_shuffle_matrix_sf_a_row_indices( - dst_w3_w1_weight, - epilogue_tile_m=epilogue_tile_m, - num_elts_per_sf=num_elts_per_sf) - # Memoize permute indices as recompute is **very** costly - self._cache_permute_indices[ - dst_w3_w1_weight.shape] = permute0[permute1].to( - dst_w3_w1_weight.device) - permute_indices = self._cache_permute_indices[dst_w3_w1_weight.shape] - return permute_indices - - def _maybe_get_cached_w2_permute_indices( - self, - dst_w2_weight: torch.Tensor, - epilogue_tile_m: int, - num_elts_per_sf: Union[None, int] = None) -> torch.Tensor: - if dst_w2_weight.shape not in self._cache_permute_indices: - if num_elts_per_sf is None: - permute_indices = (get_shuffle_matrix_a_row_indices( - dst_w2_weight, epilogue_tile_m).to(dst_w2_weight.device)) - else: - permute_indices = (get_shuffle_matrix_sf_a_row_indices( - dst_w2_weight, - epilogue_tile_m=epilogue_tile_m, - num_elts_per_sf=num_elts_per_sf).to(dst_w2_weight.device)) - # Memoize permute indices as recompute is **very** costly - self._cache_permute_indices[dst_w2_weight.shape] = permute_indices - permute_indices = self._cache_permute_indices[dst_w2_weight.shape] - return permute_indices - def create_weights(self, module: torch.nn.Module): weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4 block_scales_vec_size = 1 @@ -1261,12 +1920,18 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): w1_weight: torch.Tensor, w3_weight: torch.Tensor, dst_w3_w1_weight: torch.Tensor): - w1_weight_shard = load_weight_shard(w1_weight, module.tp_size, + device = dst_w3_w1_weight.device + assert device.type == "cuda" + w1_weight_shard = load_weight_shard(w1_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) - w3_weight_shard = load_weight_shard(w3_weight, module.tp_size, + TensorParallelMode.COLUMN, + device=device) + w3_weight_shard = load_weight_shard(w3_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) + TensorParallelMode.COLUMN, + device=device) # FIXME: this depends on the kernel internals epilogue_tile_m = 128 @@ -1279,8 +1944,8 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): dst_w1_weight.copy_(w1_weight_shard.view(dst_w1_weight.dtype)) # Get permute indices - permute_indices = self._maybe_get_cached_w3_w1_permute_indices( - dst_w3_w1_weight, epilogue_tile_m) + permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( + dst_w3_w1_weight, self._cache_permute_indices, epilogue_tile_m) # Shuffle the weight according to permute indices processed_w31_weight_shard = torch.ops.trtllm.shuffle_matrix( @@ -1294,9 +1959,13 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): def load_expert_w2_weight(self, module: torch.nn.Module, w2_weight: torch.Tensor, dst_w2_weight: torch.Tensor): - w2_weight_shard = load_weight_shard(w2_weight, module.tp_size, + device = dst_w2_weight.device + assert device.type == "cuda" + w2_weight_shard = load_weight_shard(w2_weight, + module.tp_size, module.tp_rank, - TensorParallelMode.ROW) + TensorParallelMode.ROW, + device=device) # FIXME: this depends on the kernel internals epilogue_tile_m = 128 @@ -1305,8 +1974,8 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), non_blocking=True) # Get permuted indices - permute_indices = self._maybe_get_cached_w2_permute_indices( - dst_w2_weight, epilogue_tile_m) + permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( + dst_w2_weight, self._cache_permute_indices, epilogue_tile_m) # Shuffle the weight according to permute indices processed_w2_weight = torch.ops.trtllm.shuffle_matrix( @@ -1320,12 +1989,18 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): self, module: torch.nn.Module, w1_weight_scale: torch.Tensor, w3_weight_scale: torch.Tensor, dst_w3_w1_weight_scale: torch.Tensor): - w1_weight_scale = load_weight_shard(w1_weight_scale, module.tp_size, + device = dst_w3_w1_weight_scale.device + assert device.type == "cuda" + w1_weight_scale = load_weight_shard(w1_weight_scale, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) - w3_weight_scale = load_weight_shard(w3_weight_scale, module.tp_size, + TensorParallelMode.COLUMN, + device=device) + w3_weight_scale = load_weight_shard(w3_weight_scale, + module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) + TensorParallelMode.COLUMN, + device=device) # Keep weights in device buffer # w3 dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow( @@ -1347,8 +2022,9 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): epilogue_tile_m = 128 # FIXME # Get permute indices - permute_indices = self._maybe_get_cached_w3_w1_permute_indices( + permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( dst_w3_w1_weight_scale.view(float4_sf_dtype), + self._cache_permute_indices, epilogue_tile_m, num_elts_per_sf=16) @@ -1359,7 +2035,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): # Assert should only be removed during debugging assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed" # Interleave the weight. - processed_w3_w1_weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave( + processed_w3_w1_weight_scale = torch.ops.trtllm.block_scale_interleave( w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape)) # Copy the result into device buffer dst_w3_w1_weight_scale.copy_( @@ -1369,9 +2045,13 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module, w2_weight_scale: torch.Tensor, dst_w2_weight_scale: torch.Tensor): - w2_weight_scale = load_weight_shard(w2_weight_scale, module.tp_size, + device = dst_w2_weight_scale.device + assert device.type == "cuda" + w2_weight_scale = load_weight_shard(w2_weight_scale, + module.tp_size, module.tp_rank, - TensorParallelMode.ROW) + TensorParallelMode.ROW, + device=device) # Keep weights in device buffer dst_w2_weight_scale.copy_( w2_weight_scale.view(dst_w2_weight_scale.dtype)) @@ -1385,8 +2065,9 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): assert dst_w2_weight_scale.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed" # Get permute indices - permute_indices = self._maybe_get_cached_w2_permute_indices( + permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( dst_w2_weight_scale.view(float4_sf_dtype), + self._cache_permute_indices, epilogue_tile_m, num_elts_per_sf=16) @@ -1394,7 +2075,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): w_shuffled = torch.ops.trtllm.shuffle_matrix( dst_w2_weight_scale.view(dtype=float4_sf_dtype), permute_indices) # Interleave the weight. - processed_w2_weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave( + processed_w2_weight_scale = torch.ops.trtllm.block_scale_interleave( w_shuffled) # Copy the result into device buffer dst_w2_weight_scale.copy_( @@ -1408,3 +2089,761 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod): module.fc31_scale_c.data.copy_(module.fc2_input_scale.data * module.fc31_alpha.data, non_blocking=True) + + +class MXFP4WeightFusedMoEMethod(FusedMoEMethodBase): + + def create_weights(self, + module: torch.nn.Module, + weight_dtype, + weight_vec_size, + block_scales_dtype, + block_scales_vec_size, + weight_alignment=1, + bias_dtype=None): + + def round_up(x, alignment): + return (x + alignment - 1) // alignment * alignment + + module.scaling_vector_size = 32 + intermediate_size_per_partition_padded = round_up( + module.intermediate_size_per_partition, weight_alignment) + hidden_size_padded = round_up(module.hidden_size, weight_alignment) + + w3_w1_weight_shape = (module.expert_size_per_partition, + intermediate_size_per_partition_padded * 2, + hidden_size_padded // weight_vec_size) + w2_weight_shape = (module.expert_size_per_partition, hidden_size_padded, + intermediate_size_per_partition_padded // + weight_vec_size) + + # column parallel + assert hidden_size_padded % (module.scaling_vector_size * + block_scales_vec_size) == 0 + w3_w1_weight_scale = nn.Parameter( + torch.empty(module.expert_size_per_partition, + intermediate_size_per_partition_padded * 2, + hidden_size_padded // module.scaling_vector_size // + block_scales_vec_size, + dtype=block_scales_dtype), + requires_grad=False) + module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale) + + # row parallel + assert intermediate_size_per_partition_padded % ( + module.scaling_vector_size * block_scales_vec_size) == 0 + w2_weight_scale = nn.Parameter( + torch.empty(module.expert_size_per_partition, + hidden_size_padded, + intermediate_size_per_partition_padded // + module.scaling_vector_size // block_scales_vec_size, + dtype=block_scales_dtype), + requires_grad=False) + module.register_parameter("w2_weight_scale", w2_weight_scale) + + w3_w1_bias_shape = (module.expert_size_per_partition, + intermediate_size_per_partition_padded * 2) + w2_bias_shape = (module.expert_size_per_partition, hidden_size_padded) + + super().create_weights(module, weight_dtype, w3_w1_weight_shape, + w2_weight_shape, bias_dtype, w3_w1_bias_shape, + w2_bias_shape) + + self.setup_quant_scales(module) + + @abstractmethod + def load_expert_w3_w1_weight_scale_mxfp4( + self, module: torch.nn.Module, w1_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + dst_w3_w1_weight_scale: torch.Tensor): + pass + + @abstractmethod + def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module, + w2_weight_scale: torch.Tensor, + dst_w2_weight_scale: torch.Tensor): + pass + + def load_all_mxfp4_weight_scales(self, module: torch.nn.Module, + weights: Dict, load_expert_ids: List[int], + dst_w3_w1_weight_scale: torch.Tensor, + dst_w2_weight_scale: torch.Tensor): + for local_slot_id, expert_id in enumerate(load_expert_ids): + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] + w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] + w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] + elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][ + expert_id].transpose(0, 1).contiguous() + w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk( + 2, dim=0) + w2_weight_scale = weights["down_proj_weight_scale"][ + expert_id].transpose(0, 1).contiguous() + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" + ) + + expert_idx = local_slot_id + + self.load_expert_w3_w1_weight_scale_mxfp4( + module, w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight_scale[expert_idx]) + self.load_expert_w2_weight_scale_mxfp4( + module, w2_weight_scale, dst_w2_weight_scale[expert_idx]) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # Step1: Load weight block scales. + self.load_all_mxfp4_weight_scales(module, weights, + module.initial_local_expert_ids, + module.w3_w1_weight_scale.data, + module.w2_weight_scale.data) + + # Step 2: if needed, load into shared + if self.need_load_shared_weights(module): + local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids( + ) + local_shared_w3_w1_scale_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w3_w1_weight_scale.data.shape[1:], + dtype=module.w3_w1_weight_scale.data.dtype, + device='cpu') + local_shared_w2_scale_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w2_weight_scale.data.shape[1:], + dtype=module.w2_weight_scale.data.dtype, + device='cpu') + + self.load_all_mxfp4_weight_scales(module, weights, + local_shared_load_expert_ids, + local_shared_w3_w1_scale_tensors, + local_shared_w2_scale_tensors) + + module.register_all_parameter_slot_and_to_fix_weight_fns({ + 'w3_w1_weight_scale': + local_shared_w3_w1_scale_tensors, + 'w2_weight_scale': + local_shared_w2_scale_tensors, + }) + + @abstractmethod + def setup_quant_scales(self, module: torch.nn.Module): + pass + + @abstractmethod + def get_quant_scales(self, module: torch.nn.Module, slot_start, + slot_end) -> tuple[torch.Tensor, ...]: + pass + + +class MXFP4WeightCutlassFusedMoEMethod(MXFP4WeightFusedMoEMethod): + weight_dtype = FUSED_MOE_MXFP4_WEIGHT_DTYPE + block_scales_dtype = FUSED_MOE_MXFP4_WEIGHT_BLOCK_SCALE_DTYPE + # Cutlass MoE backend requires weight elements to be 128 aligned. + weight_alignment = 128 + + def create_weights(self, module: torch.nn.Module): + weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4 + block_scales_vec_size = torch.iinfo(self.block_scales_dtype).bits // 8 + + super().create_weights(module, self.weight_dtype, weight_vec_size, + self.block_scales_dtype, block_scales_vec_size, + self.weight_alignment) + + def load_expert_w3_w1_weight(self, module: torch.nn.Module, + w1_weight: torch.Tensor, + w3_weight: torch.Tensor, + dst_w3_w1_weight: torch.Tensor): + w1_weight_shard = load_weight_shard(w1_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN) + w3_weight_shard = load_weight_shard(w3_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN) + + if len(w1_weight_shard.shape) == 2: + # Pad weights + # We already satisfy alignment factor 2 for we pad 2 MXFP4 into Uint8. + assert w1_weight_shard.dtype == torch.uint8 + w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard, + self.weight_alignment // 2, + self.weight_alignment) + assert w3_weight_shard.dtype == torch.uint8 + w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard, + self.weight_alignment // 2, + self.weight_alignment) + else: + # Pad bias. + assert len(w1_weight_shard.shape) == 1 + w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard, + self.weight_alignment) + w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard, + self.weight_alignment) + + w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) + dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype), + non_blocking=True) + + # Helper function + def load_expert_w2_weight(self, module: torch.nn.Module, + w2_weight: torch.Tensor, + dst_w2_weight: torch.Tensor): + """ + Load w2 weight for each expert. + Override this method if you need to preprocess the weights differently. + """ + w2_weight_shard = load_weight_shard(w2_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.ROW) + + if len(w2_weight_shard.shape) == 2: + # Pad weights + # We already satisfy alignment factor 2 for we pad two MXFP4 into Uint8. + assert w2_weight_shard.dtype == torch.uint8 + w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard, + self.weight_alignment // 2, + self.weight_alignment) + else: + assert len(w2_weight_shard.shape) == 1 + # Pad bias. + w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard, + self.weight_alignment) + + dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), + non_blocking=True) + + def load_expert_w3_w1_weight_scale_mxfp4( + self, module: torch.nn.Module, w1_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + dst_w3_w1_weight_scale: torch.Tensor): + device = dst_w3_w1_weight_scale.device + assert device.type == "cuda" + w1_weight_scale = load_weight_shard(w1_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w3_weight_scale = load_weight_shard(w3_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w1_weight_scale = maybe_pad_for_mxfp4( + w1_weight_scale, + self.weight_alignment // module.scaling_vector_size, + self.weight_alignment) + w3_weight_scale = maybe_pad_for_mxfp4( + w3_weight_scale, + self.weight_alignment // module.scaling_vector_size, + self.weight_alignment) + + # Keep weights in device buffer + dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale.chunk( + 2, dim=0) + dst_w3_weight_scale.copy_( + w3_weight_scale.view(dst_w3_weight_scale.dtype)) + dst_w1_weight_scale.copy_( + w1_weight_scale.view(dst_w1_weight_scale.dtype)) + + orig_shape = dst_w3_w1_weight_scale.shape + + dst_w3_w1_weight_scale.copy_( + torch.ops.trtllm.block_scale_interleave( + dst_w3_w1_weight_scale.view(float4_sf_dtype)).view( + self.block_scales_dtype).reshape(orig_shape)) + + def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module, + w2_weight_scale: torch.Tensor, + dst_w2_weight_scale: torch.Tensor): + device = dst_w2_weight_scale.device + assert device.type == "cuda" + w2_weight_scale = load_weight_shard(w2_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + w2_weight_scale = maybe_pad_for_mxfp4( + w2_weight_scale, + self.weight_alignment // module.scaling_vector_size, + self.weight_alignment) + + # Keep weights in device buffer + dst_w2_weight_scale.copy_( + w2_weight_scale.view(dst_w2_weight_scale.dtype)) + + orig_shape = dst_w2_weight_scale.shape + + dst_w2_weight_scale.copy_( + torch.ops.trtllm.block_scale_interleave( + dst_w2_weight_scale.view(float4_sf_dtype)).view( + self.block_scales_dtype).reshape(orig_shape)) + + +class W4A16MXFP4CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod): + pass + + +class W4A8MXFP4MXFP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod): + + def create_weights(self, module: torch.nn.Module): + fake_input_scale = nn.Parameter(torch.empty( + module.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + module.register_parameter("fake_input_scale", fake_input_scale) + + super().create_weights(module) + + self.setup_quant_scales(module) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # Step1: Load input scales. + module.fake_input_scale.fill_(1.) + + # Step2: Load weight block scales. + super().load_quant_scales(module, weights) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = FusedMoEQuantScalesW4A8MXFP4MXFP8( + fc31_weight_block_scale=module.w3_w1_weight_scale, + fc31_dequant_scale=module.fake_input_scale, + fc2_weight_block_scale=module.w2_weight_scale, + fc2_dequant_scale=module.fake_input_scale, + ) + + def get_quant_scales(self, module: torch.nn.Module, slot_start, + slot_end) -> tuple[torch.Tensor, ...]: + assert module.smart_router + return FusedMoEQuantScalesW4A8MXFP4MXFP8( + fc31_weight_block_scale=module.w3_w1_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc31_dequant_scale=module.fake_input_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc2_weight_block_scale=module.w2_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc2_dequant_scale=module.fake_input_scale.narrow( + 0, slot_start, slot_end - slot_start), + ) + + +class W4A8MXFP4FP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod): + + def create_weights(self, module: torch.nn.Module): + fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc31_input_scale", fc31_input_scale) + + fc31_input_dequant = nn.Parameter(torch.empty( + module.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc31_input_dequant", fc31_input_dequant) + + fc2_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc2_input_scale", fc2_input_scale) + + fc2_input_dequant = nn.Parameter(torch.empty( + module.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc2_input_dequant", fc2_input_dequant) + + super().create_weights(module) + + self.setup_quant_scales(module) + + def load_expert_fc31_input_scale_w4a8_mxfp4_fp8( + self, w1_input_scale, w3_input_scale, + dst_fc31_input_scale: torch.Tensor): + w1_input_scale = w1_input_scale[...].reshape([]) + assert torch.allclose( + w1_input_scale, w3_input_scale), "w1_input_scale != w3_input_scale" + dst_fc31_input_scale.copy_(w1_input_scale) + + def load_expert_fc2_input_scale_w4a8_mxfp4_fp8( + self, w2_input_scale, dst_fc2_input_scale: torch.Tensor): + dst_fc2_input_scale.copy_(w2_input_scale[...].reshape([])) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # Step1: Load input scales. + tmp_fc31_input_scale = torch.empty(module.num_experts, + dtype=torch.float32) + tmp_fc2_input_scale = torch.empty(module.num_experts, + dtype=torch.float32) + + for expert_id in range(module.num_experts): + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_input_scale = weights[f"{expert_id}.w1.input_scale"] + w3_input_scale = weights[f"{expert_id}.w3.input_scale"] + w2_input_scale = weights[f"{expert_id}.w2.input_scale"] + elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_input_scale = weights["gate_up_proj_input_scale"] + w3_input_scale = weights["gate_up_proj_input_scale"] + w2_input_scale = weights["down_proj_input_scale"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" + ) + + self.load_expert_fc31_input_scale_w4a8_mxfp4_fp8( + w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id]) + self.load_expert_fc2_input_scale_w4a8_mxfp4_fp8( + w2_input_scale, tmp_fc2_input_scale[expert_id]) + + # fc31_input_scale is the reciprocal of the maximum of all w1 input scales and w3 input scales. + module.fc31_input_scale.data.copy_( + tmp_fc31_input_scale.max().reciprocal()) + module.fc31_input_dequant.data.copy_(tmp_fc31_input_scale.max()) + # fc2_input_scale is the reciprocal of the maximum of all w2 input scales. + module.fc2_input_scale.data.copy_( + tmp_fc2_input_scale.max().reciprocal()) + module.fc2_input_dequant.data.copy_(tmp_fc2_input_scale.max()) + + # Step2: Load weight block scales. + super().load_quant_scales(module, weights) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = FusedMoEQuantScalesW4A8MXFP4FP8( + fc31_weight_block_scale=module.w3_w1_weight_scale, + fc31_dequant_scale=module.fc31_input_dequant, + fc2_input_scale=module.fc2_input_scale, + fc2_weight_block_scale=module.w2_weight_scale, + fc2_dequant_scale=module.fc2_input_dequant, + ) + + def get_quant_scales(self, module: torch.nn.Module, slot_start, + slot_end) -> tuple[torch.Tensor, ...]: + assert module.smart_router + return FusedMoEQuantScalesW4A8MXFP4FP8( + fc31_weight_block_scale=module.w3_w1_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc31_dequant_scale=module.fc31_input_dequant.narrow( + 0, slot_start, slot_end - slot_start), + fc2_input_scale=module.fc2_input_scale, + fc2_weight_block_scale=module.w2_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc2_dequant_scale=module.fc2_input_dequant.narrow( + 0, slot_start, slot_end - slot_start), + ) + + +class MXFP4WeightTRTLLMGenFusedMoEMethod(MXFP4WeightFusedMoEMethod): + weight_dtype = torch.uint8 + block_scales_dtype = torch.uint8 + # TRTLLM-Gen backend requires weight elements to be 256 aligned. + weight_alignment = 256 + + # Cache the permute indices during weight loading to avoid recompute + # This assumes the same input shape always results in the same permute indices + _cache_permute_indices: Dict[torch.Size, torch.Tensor] = {} + + def create_weights(self, module: torch.nn.Module): + weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4 + block_scales_vec_size = torch.iinfo(self.block_scales_dtype).bits // 8 + + super().create_weights(module, + self.weight_dtype, + weight_vec_size, + self.block_scales_dtype, + block_scales_vec_size, + self.weight_alignment, + bias_dtype=torch.float32) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = tuple() + + def get_quant_scales(self, module: torch.nn.Module, slot_start, + slot_end) -> tuple[torch.Tensor, ...]: + """ + The TRTLLM-Gen backend of FusedMoE does not use FusedMoEQuantScales. + """ + raise NotImplementedError + + def load_expert_w3_w1_weight(self, module: torch.nn.Module, + w1_weight: torch.Tensor, + w3_weight: torch.Tensor, + dst_w3_w1_weight: torch.Tensor): + device = dst_w3_w1_weight.device + assert device.type == "cuda" + w1_weight_shard = load_weight_shard(w1_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w3_weight_shard = load_weight_shard(w3_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + + if len(w1_weight_shard.shape) == 2: + # Pad weights + # We already satisfy alignment factor of 2 for we pad two MXFP4 into Uint8. + assert w1_weight_shard.dtype == torch.uint8 + w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard, + self.weight_alignment // 2, + self.weight_alignment) + assert w3_weight_shard.dtype == torch.uint8 + w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard, + self.weight_alignment // 2, + self.weight_alignment) + else: + # Pad bias, TRTLLM backend expects float32 bias. + assert len(w1_weight_shard.shape) == 1 + w1_weight_shard = maybe_pad_for_mxfp4( + w1_weight_shard, self.weight_alignment).float() + w3_weight_shard = maybe_pad_for_mxfp4( + w3_weight_shard, self.weight_alignment).float() + + # FIXME: this depends on the kernel internals + epilogue_tile_m = 128 + + # Keep weights in device buffer + dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.chunk(2, dim=0) + dst_w3_weight.copy_(w3_weight_shard.view(dst_w3_weight.dtype)) + dst_w1_weight.copy_(w1_weight_shard.view(dst_w1_weight.dtype)) + + # Get permute indices + permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( + dst_w3_w1_weight, self._cache_permute_indices, epilogue_tile_m) + + # Shuffle the weight according to permute indices + processed_w31_weight_shard = torch.ops.trtllm.shuffle_matrix( + dst_w3_w1_weight, permute_indices.to(dst_w3_w1_weight.device)) + + # Copy the result into device buffer + dst_w3_w1_weight.copy_(processed_w31_weight_shard.view( + dst_w3_w1_weight.dtype), + non_blocking=True) + + def load_expert_w2_weight(self, module: torch.nn.Module, + w2_weight: torch.Tensor, + dst_w2_weight: torch.Tensor): + device = dst_w2_weight.device + assert device.type == "cuda" + w2_weight_shard = load_weight_shard(w2_weight, + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + + if len(w2_weight_shard.shape) == 2: + # Pad weights + # We already satisfy alignment factor of 2 for we pad two MXFP4 into Uint8. + assert w2_weight_shard.dtype == torch.uint8 + w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard, + self.weight_alignment // 2, + self.weight_alignment) + else: + # Pad bias, TRTLLM backend expects float32 bias. + # Divide bias by tp_size as we shard along the hidden dimension. + # The bias is applied at each TP rank before the final accumulation. + assert len(w2_weight_shard.shape) == 1 + w2_weight_shard = maybe_pad_for_mxfp4( + w2_weight_shard, self.weight_alignment).float() / module.tp_size + + # FIXME: this depends on the kernel internals + epilogue_tile_m = 128 + + # Keep weights in device buffer + dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), + non_blocking=True) + # Get permuted indices + permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( + dst_w2_weight, self._cache_permute_indices, epilogue_tile_m) + + # Shuffle the weight according to permute indices + processed_w2_weight = torch.ops.trtllm.shuffle_matrix( + dst_w2_weight, permute_indices.to(dst_w2_weight.device)) + + # Copy the result into device buffer + dst_w2_weight.copy_(processed_w2_weight.view(dst_w2_weight.dtype), + non_blocking=True) + + def load_expert_w3_w1_weight_scale_mxfp4( + self, module: torch.nn.Module, w1_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + dst_w3_w1_weight_scale: torch.Tensor): + device = dst_w3_w1_weight_scale.device + assert device.type == "cuda" + w1_weight_scale = load_weight_shard(w1_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w3_weight_scale = load_weight_shard(w3_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN, + device=device) + w1_weight_scale = maybe_pad_for_mxfp4( + w1_weight_scale, + self.weight_alignment // module.scaling_vector_size, + self.weight_alignment) + w3_weight_scale = maybe_pad_for_mxfp4( + w3_weight_scale, + self.weight_alignment // module.scaling_vector_size, + self.weight_alignment) + + # Keep weights in device buffer + dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale.chunk( + 2, dim=0) + dst_w3_weight_scale.copy_( + w3_weight_scale.view(dst_w3_weight_scale.dtype)) + dst_w1_weight_scale.copy_( + w1_weight_scale.view(dst_w1_weight_scale.dtype)) + + orig_shape = dst_w3_w1_weight_scale.shape + + # trtllm-gen specific block scales preprocessing logics + epilogue_tile_m = 128 # FIXME + + # Get permute indices + permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( + dst_w3_w1_weight_scale.view(float4_sf_dtype), + self._cache_permute_indices, + epilogue_tile_m, + num_elts_per_sf=32) + + # Shuffle the weight according to permute indices + w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix( + dst_w3_w1_weight_scale.view(float4_sf_dtype), permute_indices) + + # Assert should only be removed during debugging + assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed" + # Interleave the weight. + processed_w3_w1_weight_scale = torch.ops.trtllm.block_scale_interleave( + w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape)) + # Copy the result into device buffer + dst_w3_w1_weight_scale.copy_( + processed_w3_w1_weight_scale.view( + self.block_scales_dtype).reshape(orig_shape)) + + def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module, + w2_weight_scale: torch.Tensor, + dst_w2_weight_scale: torch.Tensor): + device = dst_w2_weight_scale.device + assert device.type == "cuda" + # The last rank might get not full tensor, but its remainder. + # E.g. TP=8 and w2_weight_scale.shape[1] = 90, the last rank will get 6 elements. + # Take the original width, pad it to the self.weight_alignment // module.scaling_vector_size, + # Use this value as padding for the weight scales. + original_width = math.ceil(w2_weight_scale.shape[1] / module.tp_size) + sfs_alignment = self.weight_alignment // module.scaling_vector_size + padded_width = math.ceil(original_width / sfs_alignment) * sfs_alignment + w2_weight_scale = load_weight_shard(w2_weight_scale, + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + w2_weight_scale = maybe_pad_for_mxfp4(w2_weight_scale, padded_width, + self.weight_alignment) + + # Keep weights in device buffer + dst_w2_weight_scale.copy_( + w2_weight_scale.view(dst_w2_weight_scale.dtype)) + + orig_shape = dst_w2_weight_scale.shape + + # trtllm-gen specific block scales preprocessing logics + epilogue_tile_m = 128 # FIXME: read from kernel + + # Assert should only be removed during debugging + assert dst_w2_weight_scale.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed" + + # Get permute indices + permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( + dst_w2_weight_scale.view(float4_sf_dtype), + self._cache_permute_indices, + epilogue_tile_m, + num_elts_per_sf=32) + + # Shuffle the weight according to permute indices + w_shuffled = torch.ops.trtllm.shuffle_matrix( + dst_w2_weight_scale.view(dtype=float4_sf_dtype), permute_indices) + # Interleave the weight. + processed_w2_weight_scale = torch.ops.trtllm.block_scale_interleave( + w_shuffled) + # Copy the result into device buffer + dst_w2_weight_scale.copy_( + processed_w2_weight_scale.view( + self.block_scales_dtype).reshape(orig_shape)) + + +class W4A16MXFP4TRTLLMGenFusedMoEMethod(MXFP4WeightTRTLLMGenFusedMoEMethod): + pass + + +class W4A8MXFP4FP8TRTLLMGenFusedMoEMethod(MXFP4WeightTRTLLMGenFusedMoEMethod): + + def create_weights(self, module: torch.nn.Module): + fc31_input_dequant = nn.Parameter(torch.empty( + module.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc31_input_dequant", fc31_input_dequant) + fc31_input_gate_dequant = nn.Parameter(torch.empty( + module.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc31_input_gate_dequant", + fc31_input_gate_dequant) + + fc2_input_dequant = nn.Parameter(torch.empty( + module.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + module.register_parameter("fc2_input_dequant", fc2_input_dequant) + + super().create_weights(module) + + def load_expert_fc31_input_scale_w4a8_mxfp4_fp8( + self, w1_input_scale, w3_input_scale, w2_input_scale, + dst_fc31_input_scale: torch.Tensor, + dst_fc2_input_scale: torch.Tensor): + w1_input_scale = w1_input_scale[...].reshape([]) + w2_input_scale = w2_input_scale[...].reshape([]) + assert torch.allclose( + w1_input_scale, w3_input_scale), "w1_input_scale != w3_input_scale" + dst_fc31_input_scale.copy_(w1_input_scale) + dst_fc2_input_scale.copy_(w2_input_scale) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # Step1: Load input scales. + tmp_fc31_input_scale = torch.empty(module.num_experts, + dtype=torch.float32) + + tmp_fc2_input_scale = torch.empty(module.num_experts, + dtype=torch.float32) + + for expert_id in range(module.num_experts): + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_input_scale = weights[f"{expert_id}.w1.input_scale"] + w3_input_scale = weights[f"{expert_id}.w3.input_scale"] + w2_input_scale = weights[f"{expert_id}.w2.input_scale"] + elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_input_scale = weights["gate_up_proj_input_scale"] + w3_input_scale = weights["gate_up_proj_input_scale"] + w2_input_scale = weights["down_proj_input_scale"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" + ) + + self.load_expert_fc31_input_scale_w4a8_mxfp4_fp8( + w1_input_scale, w3_input_scale, w2_input_scale, + tmp_fc31_input_scale[expert_id], tmp_fc2_input_scale[expert_id]) + + module.fc31_input_dequant.data.copy_(tmp_fc31_input_scale.max() / + tmp_fc2_input_scale.max()) + module.fc31_input_gate_dequant.data.copy_(tmp_fc31_input_scale.max()) + module.fc2_input_dequant.data.copy_(tmp_fc2_input_scale.max()) + + # Step2: Load weight block scales. + super().load_quant_scales(module, weights) + + +class W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod(MXFP4WeightTRTLLMGenFusedMoEMethod): + + def create_weights(self, module: torch.nn.Module): + super().create_weights(module) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # Load weight block scales. + super().load_quant_scales(module, weights) diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 793240c2ad..34c2179593 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -5,6 +5,138 @@ from typing import Optional import torch from torch import nn +# Global cache for perfect router logits to share across all MLP blocks +_PERFECT_ROUTER_LOGITS_CACHE = {} + + +def get_perfect_router_cache_stats(): + """Get statistics about the perfect router cache.""" + global _PERFECT_ROUTER_LOGITS_CACHE + + if not _PERFECT_ROUTER_LOGITS_CACHE: + return { + "cache_size": 0, + "memory_usage_mb": 0.0, + "cached_batch_sizes": [] + } + + total_memory = 0 + cached_batch_sizes = [] + + for (num_tokens, num_experts, experts_per_token, + moe_ep_size), logits in _PERFECT_ROUTER_LOGITS_CACHE.items(): + total_memory += logits.numel() * logits.element_size() + cached_batch_sizes.append(num_tokens) + + return { + "cache_size": len(_PERFECT_ROUTER_LOGITS_CACHE), + "memory_usage_mb": total_memory / (1024 * 1024), + "cached_batch_sizes": sorted(list(set(cached_batch_sizes))) + } + + +def precompute_common_perfect_router_logits(num_experts: int, + experts_per_token: int, + moe_ep_size: int, + dtype: torch.dtype): + """ + Pre-compute logits for common batch sizes to avoid first-time computation overhead. + Only precomputes if cache is empty (avoids redundant work across multiple MLPBlock instances). + """ + # Check if cache is already populated (avoid redundant work) + cache_stats = get_perfect_router_cache_stats() + if cache_stats["cache_size"] > 0: + return + + # Common batch sizes for different scenarios + common_batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 384, + 512, + 640, + 768, + 1024, + 1536, + 2048, + 3072, + 4096, + 5120, + 6144, + 7168, + 8192 # Powers of 2 and common sizes + ] + + print( + f"Precomputing perfect router logits for {len(common_batch_sizes)} common batch sizes..." + ) + + # Precompute logits for common batch sizes using global cache + for num_tokens in common_batch_sizes: + try: + # Use the global cache function which will handle CPU computation and caching + get_cached_perfect_router_logits( + num_tokens=num_tokens, + num_experts=num_experts, + experts_per_token=experts_per_token, + moe_ep_size=moe_ep_size, + device=torch.device('cpu'), # Precompute on CPU + dtype=dtype) + + except Exception as e: + # Skip this batch size if computation fails + print( + f"Warning: Failed to precompute logits for batch size {num_tokens}: {e}" + ) + continue + + # Print cache statistics + final_stats = get_perfect_router_cache_stats() + print( + f"Perfect router cache initialized: {final_stats['cache_size']} entries, " + f"{final_stats['memory_usage_mb']:.2f} MB memory usage") + + +def get_cached_perfect_router_logits(num_tokens: int, num_experts: int, + experts_per_token: int, moe_ep_size: int, + device: torch.device, + dtype: torch.dtype) -> torch.Tensor: + """ + Get cached perfect router logits, computing and caching if not found. + Uses global cache to share across all MLP blocks. + """ + global _PERFECT_ROUTER_LOGITS_CACHE + + cache_key = (num_tokens, num_experts, experts_per_token, moe_ep_size) + + if cache_key in _PERFECT_ROUTER_LOGITS_CACHE: + # Return cached logits moved to the correct device + cached_logits = _PERFECT_ROUTER_LOGITS_CACHE[cache_key] + if cached_logits.device != device: + cached_logits = cached_logits.to(device) + # Update cache with device-specific version for future use + _PERFECT_ROUTER_LOGITS_CACHE[cache_key] = cached_logits + return cached_logits + else: + # Compute and cache new logits + logits = create_renormalize_expert_load_balanced_logits( + num_tokens=num_tokens, + num_experts=num_experts, + experts_per_token=experts_per_token, + moe_ep_size=moe_ep_size, + device=device, + dtype=dtype) + + _PERFECT_ROUTER_LOGITS_CACHE[cache_key] = logits + return logits + # The type of method in top-K routing, for use in torch custom op # Please keep this in sync with the counterpart defined in cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h @@ -37,15 +169,15 @@ class BaseMoeRoutingMethod(nn.Module): """ raise NotImplementedError("Subclasses must implement this method") - def get_experts_per_token(self): + def get_experts_per_token(self) -> int: return self.top_k @property - def experts_per_token(self): + def experts_per_token(self) -> int: return self.get_experts_per_token() @property - def routing_method_type(self): + def routing_method_type(self) -> RoutingMethodType: return RoutingMethodType.Unspecified @@ -267,3 +399,162 @@ class RenormalizeNaiveMoeRoutingMethod(RenormalizeMoeRoutingMethod): @property def routing_method_type(self) -> RoutingMethodType: return RoutingMethodType.RenormalizeNaive + + +def create_renormalize_expert_load_balanced_logits( + num_tokens: int, + num_experts: int, + experts_per_token: int, + moe_ep_size: int, + device: torch.device, + dtype: torch.dtype = torch.float32) -> torch.Tensor: + """ + Create ideal logits that produce GPU-aware expert load balanced assignment for RenormalizeMoeRoutingMethod. + + This function is specifically designed to work with RenormalizeMoeRoutingMethod, which applies + TopK selection first, then Softmax normalization on the selected experts. The function generates + logits with high values for the desired experts and low values for others, ensuring that the + TopK selection picks the intended experts for perfect load balancing. + + This is a GPU-optimized version that avoids Python loops. + + The function creates routing logits that ensure perfect load balancing across GPUs + by cycling through experts in a GPU-aware pattern. Each token is assigned to + exactly k=experts_per_token experts, distributed evenly across all GPUs. + + Strategy: + 1. First cycle through one expert from each GPU (GPU representatives) + 2. Then move to the next expert on each GPU, and so on + 3. This ensures even distribution of work across all GPUs + + Example 1: num_gpus=4, num_experts=8, experts_per_token=2, tokens=3 + experts_per_gpu = 8 // 4 = 2 + gpu_representatives = [0, 2, 4, 6] (first expert from each GPU) + final_size = 3 * 2 = 6 (total expert assignments needed) + + | i_tensor | gpu_idx | expert_offset | indices | Explanation | + |----------|---------|---------------|---------|-------------| + | 0 | 0 | 0 | 0 | GPU 0, expert 0 | + | 1 | 1 | 0 | 2 | GPU 1, expert 0 | + | 2 | 2 | 0 | 4 | GPU 2, expert 0 | + | 3 | 3 | 0 | 6 | GPU 3, expert 0 | + | 4 | 0 | 1 | 1 | GPU 0, expert 1 | + | 5 | 1 | 1 | 3 | GPU 1, expert 1 | + + Reshaped to (3, 2): [[0, 2], [4, 6], [1, 3]] + Token 0 -> experts [0, 2], Token 1 -> experts [4, 6], Token 2 -> experts [1, 3] + + Final GPU Load Balance (Example 1): + - GPU 0: 2 expert calls (expert 0 from token 0, expert 1 from token 2) + - GPU 1: 2 expert calls (expert 0 from token 0, expert 1 from token 2) + - GPU 2: 1 expert call (expert 0 from token 1) + - GPU 3: 1 expert call (expert 0 from token 1) + Note: Slight imbalance due to (3 tokens * 2 experts = 6 total work units) not being divisible by EP size (4 GPUs) + + Example 2: num_gpus=4, num_experts=8, experts_per_token=2, tokens=4 + experts_per_gpu = 8 // 4 = 2 + gpu_representatives = [0, 2, 4, 6] + final_size = 4 * 2 = 8 + + | i_tensor | gpu_idx | expert_offset | indices | Explanation | + |----------|---------|---------------|---------|-------------| + | 0 | 0 | 0 | 0 | GPU 0, expert 0 | + | 1 | 1 | 0 | 2 | GPU 1, expert 0 | + | 2 | 2 | 0 | 4 | GPU 2, expert 0 | + | 3 | 3 | 0 | 6 | GPU 3, expert 0 | + | 4 | 0 | 1 | 1 | GPU 0, expert 1 | + | 5 | 1 | 1 | 3 | GPU 1, expert 1 | + | 6 | 2 | 1 | 5 | GPU 2, expert 1 | + | 7 | 3 | 1 | 7 | GPU 3, expert 1 | + + Reshaped to (4, 2): [[0, 2], [4, 6], [1, 3], [5, 7]] + Token 0 -> experts [0, 2], Token 1 -> experts [4, 6], + Token 2 -> experts [1, 3], Token 3 -> experts [5, 7] + + Final GPU Load Balance (Example 2): + - GPU 0: 2 expert calls (expert 0 from token 0, expert 1 from token 2) + - GPU 1: 2 expert calls (expert 0 from token 0, expert 1 from token 2) + - GPU 2: 2 expert calls (expert 0 from token 1, expert 1 from token 3) + - GPU 3: 2 expert calls (expert 0 from token 1, expert 1 from token 3) + Perfect balance: Each GPU handles exactly 2 expert calls + + Args: + num_tokens: Number of tokens to route + num_experts: Total number of experts + experts_per_token: Number of experts each token should be routed to (top-k) + moe_ep_size: Number of GPUs for MoE expert parallelism + device: Device to create tensors on + dtype: Data type for the logits tensor + + Returns: + torch.Tensor: Logits tensor of shape [num_tokens, num_experts] with softmax-applied probabilities + + Raises: + ValueError: If num_experts is not divisible by moe_ep_size or if moe_ep_size is zero + """ + k = experts_per_token + experts_per_gpu = num_experts // moe_ep_size + # For expert load balance, only moe_ep_size is relevant. System could have multiple TP/gpus sharding each group of experts + num_gpus = moe_ep_size + + # Validation checks + if num_experts % moe_ep_size != 0: + raise ValueError( + f"num_experts ({num_experts}) must be divisible by moe_ep_size ({moe_ep_size})" + ) + + if moe_ep_size == 0: + raise ValueError("moe_ep_size cannot be zero") + + # Create logits tensor on the same device and dtype as input + # Shape: [num_tokens, num_experts] - will hold routing probabilities + logits = torch.zeros(num_tokens, num_experts, device=device, dtype=dtype) + + # GPU-aware expert assignment: cycle through one expert from each GPU first + final_size = num_tokens * k # Total number of expert assignments needed + + # Create GPU representatives (first expert from each GPU): [0, 8, 16, 24, ...] + # These are the starting expert indices for each GPU + gpu_representatives = torch.arange(0, + num_experts, + experts_per_gpu, + device=device) + + # Generate indices using GPU-aware pattern (vectorized) + # i_tensor: sequential indices from 0 to final_size-1 + i_tensor = torch.arange(final_size, device=device) + + # gpu_idx: which GPU this assignment should go to (cycles through 0,1,2,3,0,1,2,3,...) + gpu_idx = i_tensor % num_gpus + + # expert_offset: which expert within the GPU (0,0,0,0,1,1,1,1,2,2,2,2,...) + # This ensures we use all experts from each GPU before moving to next expert + expert_offset = (i_tensor // num_gpus) % experts_per_gpu + + # indices: actual expert indices by combining GPU base + offset + indices = gpu_representatives[gpu_idx] + expert_offset + + # Reshape to (num_tokens, k) - each row contains k expert indices for that token + expert_indices = indices.view(num_tokens, k) + + # Generate large values for selected experts (5-10 range) + # These high values ensure the selected experts have high probability after softmax + large_values = torch.full((num_tokens, k), 7.5, device=device, dtype=dtype) + + # Assign large values to selected expert positions + # token_indices: [[0,0],[1,1],[2,2],...] for indexing tokens + token_indices = torch.arange(num_tokens, + device=device).unsqueeze(1).expand(-1, k) + logits[token_indices, expert_indices] = large_values + + # Fill remaining positions with small values (0-1 range) + # This ensures non-selected experts have low but non-zero probability + mask = (logits == 0) + logits[mask] = 0.5 + + # Apply softmax to get probabilities + # After softmax, selected experts will have high probability (~0.99) + # while non-selected experts will have very low probability + logits = torch.nn.functional.softmax(logits, dim=-1) + + return logits diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index 8b3e314a9e..d7c20fe8f0 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -27,7 +27,8 @@ class GatedMLP(nn.Module): config: Optional[ModelConfig] = None, overridden_tp_size: Optional[int] = None, reduce_output: bool = True, - layer_idx: Optional[int] = None): + layer_idx: Optional[int] = None, + use_cute_dsl_blockscaling_mm: bool = False): super().__init__() self.layer_idx = layer_idx self.hidden_size = hidden_size @@ -64,7 +65,8 @@ class GatedMLP(nn.Module): reduce_output=False, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm) self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], [self.hidden_size]) @@ -81,7 +83,8 @@ class GatedMLP(nn.Module): skip_create_weights_in_init=config.skip_create_weights_in_init, lora=self.down_lora, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm) # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used, # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora diff --git a/tensorrt_llm/_torch/modules/layer_norm.py b/tensorrt_llm/_torch/modules/layer_norm.py new file mode 100644 index 0000000000..518863e550 --- /dev/null +++ b/tensorrt_llm/_torch/modules/layer_norm.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +from torch import nn + + +class LayerNorm(nn.Module): + """Layer normalization module with configurable weight and bias parameters. + + This implementation provides standard layer normalization with optional + learnable parameters and residual connection support. + + Args: + hidden_size: The size of the hidden dimension to normalize. + eps: Small constant for numerical stability. + dtype: Optional data type for parameters. + device: Optional device for parameters. + has_weights: Whether to include learnable weight parameters. + has_bias: Whether to include learnable bias parameters. + """ + + def __init__( + self, + *, + hidden_size: int, + eps: float, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + has_weights: bool = True, + has_bias: bool = True, + ): + super().__init__() + if has_weights: + self.weight = nn.Parameter( + torch.ones(hidden_size, dtype=dtype, device=device)) + else: + self.register_buffer('weight', + torch.ones(hidden_size, + dtype=dtype, + device=device), + persistent=False) + if has_bias: + self.bias = nn.Parameter( + torch.zeros(hidden_size, dtype=dtype, device=device)) + else: + self.register_buffer('bias', + torch.zeros(hidden_size, + dtype=dtype, + device=device), + persistent=False) + self.variance_epsilon = eps + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] = ..., + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Apply layer normalization to input tensor. + + Args: + hidden_states: Input tensor to normalize. + residual: Optional residual tensor to add before normalization. + + Returns: + Normalized tensor, or tuple of (normalized_tensor, residual) if residual provided. + """ + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + if isinstance(residual, torch.Tensor): + hidden_states = hidden_states + residual.to(torch.float32) + residual = hidden_states.to(input_dtype) + + hidden_states = nn.functional.layer_norm( + hidden_states, + hidden_states.shape[-1], + weight=self.weight, + bias=self.bias, + eps=self.variance_epsilon, + ) + + if residual is ...: + return hidden_states + else: + return hidden_states, residual + + def skip_forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] = ..., + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Skip normalization and return inputs unchanged. + + Args: + hidden_states: Input tensor to pass through. + residual: Optional residual tensor to pass through. + + Returns: + Input tensors unchanged, maintaining same signature as forward. + """ + + if residual is ...: + return hidden_states + else: + return hidden_states, residual diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 44d69076fc..67d49b3d94 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -111,7 +111,10 @@ def copy_weight(dst: Parameter, src: torch.Tensor): dst.data.copy_(src) -def load_weights_vanilla_helper(module: Linear, weights: List[Dict]): +def load_weights_vanilla_helper(module: Linear, + weights: List[Dict], + weight_transform=lambda x: x, + bias_transform=lambda x: x): assert len(weights) == 1 device = torch.device('cuda') @@ -127,17 +130,20 @@ def load_weights_vanilla_helper(module: Linear, weights: List[Dict]): weight.T.to(torch.int8).contiguous().cpu(), weight_dtype, activation_dtype).cuda().contiguous() - copy_weight(module.weight, weight) + copy_weight(module.weight, weight_transform(weight)) if module.bias is not None: bias = load_weight_shard(weights[0]['bias'], module.tp_size, module.tp_rank, module.tp_mode, device) - copy_weight(module.bias, bias) + copy_weight(module.bias, bias_transform(bias)) def load_weights_fused_qkv_helper( - module: Linear, - weights: List[Dict]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + module: Linear, + weights: List[Dict], + weight_transform=lambda x: x, + bias_transform=lambda x: x +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert len(weights) == 3 device = torch.device('cuda') @@ -155,14 +161,17 @@ def load_weights_fused_qkv_helper( module.tp_rank, module.tp_mode, device) v_bias = load_weight_shard(weights[2]['bias'], module.tp_size, module.tp_rank, module.tp_mode, device) - copy_weight(module.bias, torch.cat((q_bias, k_bias, v_bias))) + copy_weight(module.bias, + bias_transform(torch.cat((q_bias, k_bias, v_bias)))) - return (q_weight, k_weight, v_weight) + return tuple(map(weight_transform, (q_weight, k_weight, v_weight))) def load_weights_fused_gate_up_helper( module: Linear, - weights: List[Dict]) -> tuple[torch.Tensor, torch.Tensor]: + weights: List[Dict], + weight_transform=lambda x: x, + bias_transform=lambda x: x) -> tuple[torch.Tensor, torch.Tensor]: assert len(weights) == 2 device = torch.device('cuda') @@ -175,8 +184,9 @@ def load_weights_fused_gate_up_helper( module.tp_rank, module.tp_mode, device) up_bias = load_weight_shard(weights[1]['bias'], module.tp_size, module.tp_rank, module.tp_mode, device) - copy_weight(module.bias, torch.cat((up_bias, gate_bias))) - return (gate_weight, up_weight) + copy_weight(module.bias, bias_transform(torch.cat( + (gate_bias, up_bias)))) + return tuple(map(weight_transform, (gate_weight, up_weight))) def get_weight_dtype_and_id(module: Linear) -> tuple[torch.dtype, int]: @@ -573,21 +583,29 @@ class FP8BlockScalesLinearMethod(LinearMethodBase): assert input.dtype == torch.bfloat16 if get_sm_version() == 100: - from tensorrt_llm import deep_gemm - a, a_sf = fp8_utils.per_token_quant_and_transform(input) - output = torch.empty((input.shape[0], module.weight.shape[0]), - device=input.device, - dtype=torch.bfloat16) - deep_gemm.fp8_gemm_nt((a, a_sf), - (module.weight, module.weight_scale), - output, - disable_ue8m0_cast=True) + if module.use_cute_dsl_blockscaling_mm: + # TODO (@lmin): replace with cute_dsl gemm + act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( + input) + output = torch.ops.trtllm.fp8_block_scaling_gemm( + act_input_fp8, module.weight, act_input_sf, + module.weight_scale) + else: + from tensorrt_llm import deep_gemm + a, a_sf = fp8_utils.per_token_quant_and_transform(input) + output = torch.empty((input.shape[0], module.weight.shape[0]), + device=input.device, + dtype=torch.bfloat16) + deep_gemm.fp8_gemm_nt((a, a_sf), + (module.weight, module.weight_scale), + output, + disable_ue8m0_cast=True) else: act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( input) - output = torch.ops.trtllm.fp8_block_scaling_gemm( act_input_fp8, module.weight, act_input_sf, module.weight_scale) + if bias is not None: output = output + bias return output @@ -687,6 +705,8 @@ class NVFP4LinearMethod(LinearMethodBase): bias: Optional[torch.Tensor]): if isinstance(input, Fp4QuantizedTensor): act_fp4, act_sf = input.fp4_tensor, input.scaling_factor + elif isinstance(input, tuple): + act_fp4, act_sf = input else: act_fp4, act_sf = torch.ops.trtllm.fp4_quantize( input, module.input_scale, module.scaling_vector_size, False) @@ -752,8 +772,7 @@ class NVFP4LinearMethod(LinearMethodBase): assert len(weights) == 1 weight_scale = weight_scale[0] # Swizzle weight scale - weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave( - weight_scale) + weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale) copy_weight(module.input_scale, input_scale) copy_weight(module.weight_scale, weight_scale) @@ -773,8 +792,7 @@ class NVFP4LinearMethod(LinearMethodBase): tp_mode=module.tp_mode) # Swizzle weight scales after concatenation weight_scale = torch.cat(weight_scales, 0) - weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave( - weight_scale) + weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale) copy_weight(module.input_scale, input_scale) copy_weight(module.weight_scale, weight_scale) copy_weight(module.alpha, alpha) @@ -796,8 +814,7 @@ class NVFP4LinearMethod(LinearMethodBase): tp_mode=module.tp_mode) # Swizzle weight scales after concatenation weight_scale = torch.cat(weight_scales, 0) - weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave( - weight_scale) + weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale) copy_weight(module.input_scale, input_scale) copy_weight(module.weight_scale, weight_scale) copy_weight(module.alpha, alpha) @@ -880,8 +897,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase): assert len(weights) == 1 weight_scale = weight_scale[0] # Swizzle weight scale - weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave( - weight_scale) + weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale) copy_weight(module.weight_scale, weight_scale) def load_weights_fused_qkv_linear(self, module: Linear, @@ -896,8 +912,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase): tp_rank=module.tp_rank, tp_mode=module.tp_mode) weight_scale = torch.cat(weight_scale, 0) - weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave( - weight_scale) + weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale) copy_weight(module.weight_scale, weight_scale) def load_weights_fused_gate_up_linear(self, module: Linear, @@ -913,8 +928,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase): tp_mode=module.tp_mode) # Swizzle weight scales after concatenation weight_scale = torch.cat(weight_scale, 0) - weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave( - weight_scale) + weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale) copy_weight(module.weight_scale, weight_scale) @@ -1113,8 +1127,9 @@ class W4A16_AWQ_LinearMethod(LinearMethodBase): def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: load_weights_vanilla_helper(module, weights) - device = torch.device('cuda') - + # Use the same device as the weight tensor + # as we register pre_quant_scale after sharded model weights are moved to respective gpus + device = module.weight.device pre_quant_scale = load_weight_shard( weights[0]["pre_quant_scale"], module.tp_size, @@ -1210,6 +1225,10 @@ class W4A8_AWQ_LinearMethod(LinearMethodBase): module.alpha = Parameter(torch.empty([1], dtype=torch.float32), requires_grad=False) + # WAR for CUDA graph. Mixed w4a8 gemm does not accept alpha in device buffer. + # Hence we prepare a separate plain float to be updated during the weight load. + module.alpha_value = 1.0 + if bias: module.bias = Parameter(torch.empty((out_features), dtype=dtype), requires_grad=False) @@ -1246,7 +1265,7 @@ class W4A8_AWQ_LinearMethod(LinearMethodBase): has_zero_point=module.quant_config.has_zero_point, output_dtype=module.dtype or input.dtype, # NOTE: output_dtype can only be bf16/fp16 for W4A8 - alpha=module.alpha.item(), + alpha=module.alpha_value, bias=bias, zeros=None) @@ -1294,7 +1313,9 @@ class W4A8_AWQ_LinearMethod(LinearMethodBase): def load_weights_vanilla(self, module: Linear, weights: List[Dict]): load_weights_vanilla_helper(module, weights) - device = torch.device('cuda') + # Use the same device as the weight tensor + # as we register pre_quant_scale after sharded model weights are moved to respective gpus + device = module.weight.device pre_quant_scale = load_weight_shard( weights[0]["pre_quant_scale"], module.tp_size, @@ -1326,6 +1347,8 @@ class W4A8_AWQ_LinearMethod(LinearMethodBase): copy_weight(module.input_scale, input_scale) copy_weight(module.alpha, alpha) + module.alpha_value = alpha.item() + module.inv_input_scale.data = 1.0 / module.input_scale def load_weights_fused_qkv_linear(self, module: Linear, @@ -1354,17 +1377,20 @@ class W4A8_AWQ_LinearMethod(LinearMethodBase): copy_weight(module.input_scale, input_scale) copy_weight(module.alpha, alpha) + module.alpha_value = alpha.item() # NOTE: pre_quant_scale is the same for q,k,v since modelopt checks which layer shared the same input and create an avg pre_quant_scale # Usually when modelopt exports the quantized model, pre_quant_Scale is fused in the layer norm (this case relevant if fused is disabled - modelopt internal) if "pre_quant_scale" in weights[0].keys(): - + # Use the same device as the weight tensor + # as we register pre_quant_scale after sharded model weights are moved to respective gpus + device = module.weight.device pre_quant_scale = load_weight_shard( weights[0]["pre_quant_scale"], module.tp_size, module.tp_rank, # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around TensorParallelMode.flip(module.tp_mode), - torch.device('cuda'), + device, ) module.pre_quant_scale = Parameter( @@ -1398,14 +1424,19 @@ class W4A8_AWQ_LinearMethod(LinearMethodBase): copy_weight(module.input_scale, input_scale) copy_weight(module.alpha, alpha) + module.alpha_value = alpha.item() + if "pre_quant_scale" in weights[0].keys(): + # Use the same device as the weight tensor + # as we register pre_quant_scale after sharded model weights are moved to respective gpus + device = module.weight.device pre_quant_scale = load_weight_shard( weights[0]["pre_quant_scale"], module.tp_size, module.tp_rank, # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around TensorParallelMode.flip(module.tp_mode), - torch.device('cuda'), + device, ) # NOTE:Create this tensor in load_weights, since not all layer have this tensor and memory is not allocated for it (same as W4A16) @@ -1416,6 +1447,27 @@ class W4A8_AWQ_LinearMethod(LinearMethodBase): copy_weight(module.pre_quant_scale, pre_quant_scale) +class W4A8MXFP4MXFP8LinearMethod(W4A8MXFP4FP8LinearMethod): + + def create_weights(self, module: Linear, in_features: int, + out_features: int, bias: bool, dtype: torch.dtype): + super().create_weights(module, in_features, out_features, bias, dtype) + module.scale_one = torch.tensor([1.0], dtype=torch.float32).cuda() + + def apply(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor]): + # requires the swizzled block scales. + fp8_input, input_scales = torch.ops.trtllm.mxfp8_quantize(input, True) + output = torch.ops.trtllm.w4a8_mxfp4_fp8_gemm(fp8_input, module.weight, + input_scales, + module.weight_scale, + module.scale_one, + module.dtype) + if bias is not None: + output = output + bias + return output + + def get_quant_method(quant_config: Optional[QuantConfig] = None): if quant_config is None or not quant_config.layer_quant_mode.has_any_quant( exclude_kv_cache=True): @@ -1439,6 +1491,8 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None): if quant_config.layer_quant_mode.is_int4_weight_only_per_group( ) and quant_config.quant_algo == QuantAlgo.W4A8_AWQ: return W4A8_AWQ_LinearMethod() + if quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8(): + return W4A8MXFP4MXFP8LinearMethod() raise ValueError(f'unsupported quant mode: {quant_config.quant_mode}') @@ -1461,6 +1515,7 @@ class Linear(nn.Module): lora: Optional[LoraLayer] = None, allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO, force_dynamic_quantization: bool = False, + use_cute_dsl_blockscaling_mm: bool = False, ): from ..distributed import AllReduce @@ -1477,6 +1532,7 @@ class Linear(nn.Module): self.tp_mode = tensor_parallel_mode self.gather_output = gather_output self.force_dynamic_quantization = force_dynamic_quantization + self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm local_in_features = in_features local_out_features = out_features @@ -1498,9 +1554,9 @@ class Linear(nn.Module): self.in_features = local_in_features self.out_features = local_out_features - self.all_reduce = AllReduce( - mapping=self.mapping, - strategy=allreduce_strategy) if reduce_output else None + self.all_reduce = AllReduce(mapping=self.mapping, + strategy=allreduce_strategy, + dtype=self.dtype) if reduce_output else None self._weights_created = False self.reduce_output = reduce_output self.use_custom_cublas_mm = use_custom_cublas_mm @@ -1516,11 +1572,14 @@ class Linear(nn.Module): if not skip_create_weights_in_init: self.create_weights() + def get_quant_method(self, quant_config: Optional[QuantConfig] = None): + return get_quant_method(quant_config) + def create_weights(self): if self._weights_created: return - self.quant_method = get_quant_method(self.quant_config) + self.quant_method = self.get_quant_method(self.quant_config) self.quant_method.create_weights(self, self.in_features, self.out_features, self.has_bias, self.dtype) @@ -1575,6 +1634,12 @@ class Linear(nn.Module): return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( ) and self.quant_config.quant_algo == QuantAlgo.W4A8_AWQ + @property + def has_w4a8_mxfp4_fp8(self): + assert self._weights_created + return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8( + ) + def apply_linear(self, input, bias, diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 7b5d65a7a1..6ea096bb6a 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -147,6 +147,8 @@ class Mamba2Mixer(nn.Module): quant_config=config.get_quant_config(), allreduce_strategy=config.allreduce_strategy) + self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype + def forward( self, hidden_states: torch.Tensor, @@ -230,6 +232,7 @@ class Mamba2Mixer(nn.Module): seq_idx=seq_idx, return_varlen_states=True, return_final_states=False, + mamba_ssm_cache_dtype=self._mamba_ssm_cache_dtype, ) out.append(rearrange(y, "b l h p -> (b l) (h p)")) diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py index 42f4eb7d77..0a6f18bb63 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py @@ -16,6 +16,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch from einops import rearrange @@ -43,6 +45,7 @@ def _mamba_chunk_scan_combined_fwd( cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + mamba_ssm_cache_dtype=None, ): batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape @@ -120,7 +123,7 @@ def _mamba_chunk_scan_combined_fwd( if initial_states is not None else None), seq_idx=seq_idx, chunk_size=chunk_size, - out_dtype=C.dtype, + out_dtype=mamba_ssm_cache_dtype or C.dtype, is_cont_batched=cu_seqlens is not None) states, final_states = [ rearrange(t, "... (p n) -> ... p n", n=dstate) @@ -174,24 +177,26 @@ def _mamba_chunk_scan_combined_fwd( return out, out_x, dt, dA_cumsum, states, final_states, varlen_states -def mamba_chunk_scan_combined(x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - return_final_states=False, - return_varlen_states=False): +def mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_final_states=False, + return_varlen_states=False, + mamba_ssm_cache_dtype: Optional[torch.dtype] = None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -207,6 +212,7 @@ def mamba_chunk_scan_combined(x, seq_idx: (batch, seqlen) cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt + mamba_ssm_cache_dtype: torch.dtype, default to None Return: out: (batch, seqlen, nheads, headdim) """ @@ -231,7 +237,8 @@ def mamba_chunk_scan_combined(x, chunk_offsets=chunk_offsets, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, - dt_limit=dt_limit) + dt_limit=dt_limit, + mamba_ssm_cache_dtype=mamba_ssm_cache_dtype) if not return_varlen_states: return out if not return_final_states else (out, final_states) else: diff --git a/tensorrt_llm/_torch/modules/multi_stream_utils.py b/tensorrt_llm/_torch/modules/multi_stream_utils.py index e91b7eac24..c7b58c0896 100644 --- a/tensorrt_llm/_torch/modules/multi_stream_utils.py +++ b/tensorrt_llm/_torch/modules/multi_stream_utils.py @@ -1,8 +1,35 @@ +import threading +from contextlib import contextmanager from typing import Any, Callable, Optional import torch -from ..pyexecutor.cuda_graph_runner import is_graph_capturing + +class do_multi_stream_local(threading.local): + + def __init__(self): + self.do_multi_stream = False + + +_local = do_multi_stream_local() + + +def set_do_multi_stream(enable: bool): + _local.do_multi_stream = enable + + +def do_multi_stream() -> bool: + return _local.do_multi_stream + + +@contextmanager +def with_multi_stream(enable: bool): + prev_do_multi_stream = _local.do_multi_stream + set_do_multi_stream(enable) + try: + yield + finally: + set_do_multi_stream(prev_do_multi_stream) def maybe_execute_in_parallel( @@ -30,9 +57,9 @@ def maybe_execute_in_parallel( tuple[Any, Any]: the return values of fn0() and fn1() """ - do_multi_stream = is_graph_capturing() and aux_stream is not None + multi_stream = do_multi_stream() and aux_stream is not None - if do_multi_stream: + if multi_stream: event0.record() result0 = fn0() diff --git a/tensorrt_llm/_torch/modules/rms_norm.py b/tensorrt_llm/_torch/modules/rms_norm.py index 5eef7a6d00..39787b82b7 100644 --- a/tensorrt_llm/_torch/modules/rms_norm.py +++ b/tensorrt_llm/_torch/modules/rms_norm.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import enum from typing import Optional, Tuple, Union @@ -9,13 +24,15 @@ from ..custom_ops import IS_FLASHINFER_AVAILABLE class RMSNorm(nn.Module): - def __init__(self, - *, - hidden_size: int, - eps: float, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - has_weights: bool = True): + def __init__( + self, + *, + hidden_size: int, + eps: float, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + has_weights: bool = True, + ): super().__init__() if has_weights: self.weight = nn.Parameter( @@ -48,6 +65,7 @@ class RMSNorm(nn.Module): if isinstance(residual, torch.Tensor): hidden_states = hidden_states + residual.to(torch.float32) residual = hidden_states.to(input_dtype) + variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) diff --git a/tensorrt_llm/_torch/modules/triton_linear.py b/tensorrt_llm/_torch/modules/triton_linear.py new file mode 100644 index 0000000000..dfc3d584e6 --- /dev/null +++ b/tensorrt_llm/_torch/modules/triton_linear.py @@ -0,0 +1,418 @@ +from __future__ import annotations + +from typing import Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from tensorrt_llm._torch.peft.lora.layer import LoraLayer +from tensorrt_llm.mapping import Mapping + +from ...models.modeling_utils import QuantConfig +# Reuse the common Triton import setup +from .fused_moe.fused_moe_triton import (IS_TRITON_KERNELS_AVAILABLE, + maybe_update_stride, + swizzle_weight_and_scale) + +if IS_TRITON_KERNELS_AVAILABLE: + from triton_kernels.matmul_ogs import (FlexCtx, PrecisionConfig, matmul_ogs) + from triton_kernels.numerics import InFlexData + +from .linear import (Linear, LinearMethodBase, TensorParallelMode, + WeightsLoadingConfig, copy_weight, load_weight_shard, + load_weights_fused_gate_up_helper, + load_weights_fused_qkv_helper, load_weights_vanilla_helper) + + +class TritonUnquantizedLinearMethod(LinearMethodBase): + + def __init__(self): + super().__init__() + self.param_transform = { + "weight_transform": lambda x: x.T.unsqueeze(0), + "bias_transform": lambda x: x.unsqueeze(0) + } + + def create_weights(self, module: Linear, in_features: int, + out_features: int, bias: bool, dtype: torch.dtype): + weight_shape = (1, in_features, out_features) + module.weight = Parameter(torch.empty(weight_shape, dtype=dtype), + requires_grad=False) + + if bias: + module.bias = Parameter( + torch.empty((1, out_features), dtype=torch.float32 + ), # Triton kernels expect bias in float32 + requires_grad=False) + else: + module.register_parameter("bias", None) + + def apply(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor]): + output = matmul_ogs( + input, + module.weight, + module.bias, + None, # Routing data is not used here + gather_indx=None, + precision_config=None) + return output + + def load_weights_vanilla(self, module: Linear, weights: List[Dict]): + load_weights_vanilla_helper(module, weights, **self.param_transform) + module.weight.data = maybe_update_stride(module.weight.data) + + def load_weights_fused_qkv_linear(self, module: Linear, + weights: List[Dict]): + q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( + module, weights, **self.param_transform) + fused_weight = torch.cat( + (q_weight, k_weight, v_weight), axis=-1 + ) #Each of them has shape (1, in_features, out_features_part) + copy_weight(module.weight, fused_weight) + module.weight.data = maybe_update_stride(module.weight.data) + + def load_weights_fused_gate_up_linear(self, module: Linear, + weights: List[Dict]): + gate_weight, up_weight = load_weights_fused_gate_up_helper( + module, weights, **self.param_transform) + fused_weight = torch.cat( + (gate_weight, up_weight), axis=-1 + ) #Each of them has shape (1, in_features, out_features_part) + copy_weight(module.weight, fused_weight) + module.weight.data = maybe_update_stride(module.weight.data) + + +class TritonFP8QDQLinearMethod(LinearMethodBase): + + def __init__(self): + super().__init__() + self.param_transform = { + "weight_transform": lambda x: x.T.unsqueeze(0), + "bias_transform": lambda x: x.unsqueeze(0) + } + + def create_weights(self, module: Linear, in_features: int, + out_features: int, bias: bool, dtype: torch.dtype): + weight_shape = (1, in_features, out_features) + module.weight = Parameter(torch.empty(weight_shape, + dtype=torch.float8_e4m3fn), + requires_grad=False) + module.weight_scale = Parameter(torch.empty((1, ), dtype=torch.float32), + requires_grad=False) + module.input_scale = Parameter(torch.empty((1, ), dtype=torch.float32), + requires_grad=False) + + if bias: + module.bias = Parameter( + torch.empty((1, out_features), dtype=torch.float32 + ), # Triton kernels expect bias in float32 + requires_grad=False) + else: + module.register_parameter("bias", None) + + def apply(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor]): + cur_input_scale = module.input_scale + if input.dtype != torch.float8_e4m3fn: + if module.input_scale is not None: + # Static quantization + qinput, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + input, module.input_scale) + else: + # Dynamic quantization + qinput, cur_input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( + input) + cur_input_scale = cur_input_scale.to(torch.float32) + + else: + qinput = input + + flex_ctx = FlexCtx( + lhs_data=InFlexData(scale=cur_input_scale), + rhs_data=InFlexData(scale=module.weight_scale), + ) + pc = PrecisionConfig(flex_ctx=flex_ctx, + allow_tf32=False, + out_dtype=module.dtype) + output = matmul_ogs( + qinput, + module.weight, + module.bias, + None, # Routing data is not used here + gather_indx=None, + precision_config=pc) + return output + + def load_weight_scales(self, weights: List[Dict]): + input_scale, weight_scale = [], [] + for w in weights: + if "input_scale" in w: + input_scale.append(w["input_scale"][...].reshape((1, ))) + if "weight_scale" in w: + weight_scale.append(w["weight_scale"][...].reshape((1, ))) + return input_scale, weight_scale + + def load_weights_vanilla(self, module: Linear, weights: List[Dict]): + load_weights_vanilla_helper(module, weights, **self.param_transform) + input_scale, weight_scale = self.load_weight_scales(weights) + if len(input_scale) != 0: + # Static quantization + copy_weight(module.input_scale, input_scale[0]) + else: + # Dynamic quantization + module.input_scale = None + copy_weight(module.weight_scale, weight_scale[0]) + module.weight.data = maybe_update_stride(module.weight.data) + + def load_weights_fused_qkv_linear(self, module: Linear, + weights: List[Dict]): + q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( + module, weights, **self.param_transform) + + input_scale, weight_scale = self.load_weight_scales(weights) + if len(input_scale) != 0: + # Static quantization + copy_weight(module.input_scale, max(input_scale)) + else: + # Dynamic quantization + module.input_scale = None + copy_weight(module.weight_scale, max(weight_scale)) + + q_weight = q_weight.to(module.dtype) * weight_scale[0] + k_weight = k_weight.to(module.dtype) * weight_scale[1] + v_weight = v_weight.to(module.dtype) * weight_scale[2] + + fused_weight = torch.cat((q_weight, k_weight, v_weight)) + fused_weight = (fused_weight / module.weight_scale).to( + torch.float8_e4m3fn) + copy_weight(module.weight, + self.param_transform["weight_transform"](fused_weight)) + module.weight.data = maybe_update_stride(module.weight.data) + + def load_weights_fused_gate_up_linear(self, module: Linear, + weights: List[Dict]): + input_scale, weight_scale = self.load_weight_scales(weights) + if len(input_scale) != 0: + # Static quantization + copy_weight(module.input_scale, max(input_scale)) + else: + # Dynamic quantization + module.input_scale = None + copy_weight(module.weight_scale, max(weight_scale)) + + gate_weight, up_weight = load_weights_fused_gate_up_helper( + module, weights, **self.param_transform) + + gate_weight = gate_weight.to(module.dtype) * weight_scale[0] + up_weight = up_weight.to(module.dtype) * weight_scale[1] + fused_weight = torch.cat((gate_weight, up_weight)) + fused_weight = (fused_weight / module.weight_scale).to( + torch.float8_e4m3fn) + copy_weight(module.weight, + self.param_transform["weight_transform"](fused_weight)) + module.weight.data = maybe_update_stride(module.weight.data) + + +class TritonMXFP4LinearMethod(LinearMethodBase): + + def __init__(self, activation_dtype): + super().__init__() + assert activation_dtype in [torch.float8_e4m3fn, torch.bfloat16], \ + f"TritonMXFP4LinearMethod only supports float8_e4m3fn or bfloat16 activation, got {activation_dtype}" + self.activation_dtype = activation_dtype + + def create_weights(self, module: Linear, in_features: int, + out_features: int, bias: bool, dtype: torch.dtype): + # Create weight + assert in_features % 2 == 0, "in_features must be even for MXFP4" + weight_shape = (1, in_features // 2, out_features) + module.weight = Parameter(torch.empty(weight_shape, dtype=torch.uint8), + requires_grad=False) + + # Create weight scale + scale_shape = (1, in_features // 32, out_features + ) # Block size is 32 for MXFP4 + module.weight_scale = Parameter(torch.empty(scale_shape, + dtype=torch.uint8), + requires_grad=False) + + # Create bias + if bias: + module.bias = Parameter( + torch.empty((1, out_features), dtype=torch.float32 + ), # Triton kernels expect bias in float32 + requires_grad=False) + else: + module.bias = None + + # Create input scale + if self.activation_dtype == torch.float8_e4m3fn: + module.input_scale = Parameter(torch.empty((1, ), + dtype=torch.float32), + requires_grad=False) + else: + module.input_scale = None + + def apply(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor]): + if self.activation_dtype == torch.float8_e4m3fn: + if input.dtype != torch.float8_e4m3fn: + if module.input_scale is not None: + # Static quantization + input, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + input, module.input_scale) + input_scale = module.input_scale + else: + # Dynamic quantization + input, input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( + input) + else: + assert module.input_scale is not None + input_scale = module.input_scale + + if self.activation_dtype == torch.float8_e4m3fn: + flex_ctx = FlexCtx(lhs_data=InFlexData(scale=input_scale), ) + else: + flex_ctx = FlexCtx() + pc = PrecisionConfig(weight_scale=module.weight_scale, + flex_ctx=flex_ctx, + allow_tf32=False, + out_dtype=module.dtype) + output = matmul_ogs( + input, + module.weight, + module.bias, + None, # Routing data is not used here + gather_indx=None, + precision_config=pc) + return output + + def load_weights_common(self, module: Linear, weights_list: List[Dict]): + device = torch.device('cuda') + processed_weights = [] + weight_scales = [] + biases = [] + input_scales = [] + for w in weights_list: + current_weight = load_weight_shard(w['weight'], module.tp_size, + module.tp_rank, module.tp_mode, + device) + current_scale = load_weight_shard(w['weight_scale'], module.tp_size, + module.tp_rank, module.tp_mode, + device) + current_bias = load_weight_shard( + w['bias'], module.tp_size, module.tp_rank, module.tp_mode, + device) if module.bias is not None else None + + processed_weights.append(current_weight) + weight_scales.append(current_scale) + if current_bias is not None: + biases.append(current_bias) + if "input_scale" in w: + input_scales.append(w["input_scale"][...].reshape([])) + # handle weights + fused_weight = torch.cat( + processed_weights) # (out_features, in_features//2) + fused_weight = fused_weight.T.unsqueeze( + 0) # (1, in_features//2, out_features) + + # handle scales + fused_scale = torch.cat( + weight_scales) # (out_features, in_features//32) + fused_scale = fused_scale.T.unsqueeze( + 0) # (1, in_features//32, out_features) + fused_weight, fused_scale = swizzle_weight_and_scale( + fused_weight, fused_scale) + assert module.weight_scale.dtype == fused_scale.dtype + # We need to use Triton tensor wrapper instead of Torch tensor to maintain the correct swizzling layout + module._parameters.pop('weight', None) + module._parameters.pop('weight_scale', None) + torch.cuda.empty_cache() + module.weight = fused_weight + module.weight_scale = fused_scale + + # handle biases + if module.bias is not None: + fused_bias = torch.cat(biases) # (out_features, ) + fused_bias = fused_bias.unsqueeze(0) # (1, out_features) + copy_weight(module.bias, fused_bias) + + # handle input scales + if len(input_scales) != 0: + # Static quantization + max_input_scale = torch.tensor(max(input_scales)).reshape((1, )) + copy_weight(module.input_scale, max_input_scale) + else: + # Dynamic quantization + module.input_scale = None + + def load_weights_vanilla(self, module: Linear, weights: List[Dict]): + assert len(weights) == 1 + self.load_weights_common(module, weights) + + def load_weights_fused_qkv_linear(self, module: Linear, + weights: List[Dict]): + assert len(weights) == 3 + self.load_weights_common(module, weights) + + def load_weights_fused_gate_up_linear(self, module: Linear, + weights: List[Dict]): + assert len(weights) == 2 + self.load_weights_common(module, weights) + + +class TritonLinear(Linear): + """ + A Linear module that uses Triton for the forward pass. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + mapping: Optional[Mapping] = None, + tensor_parallel_mode: Optional[TensorParallelMode] = None, + gather_output: bool = False, # COLUMN parallel only + quant_config: Optional[QuantConfig] = None, + weights_loading_config: Optional[WeightsLoadingConfig] = None, + reduce_output: bool = True, # ROW parallel only + skip_create_weights_in_init: bool = False, + use_custom_cublas_mm: bool = False, + lora: Optional[LoraLayer] = None, + ): + if not IS_TRITON_KERNELS_AVAILABLE: + raise ImportError("Triton kernels are not available. " + "Please install the required dependencies.") + assert not use_custom_cublas_mm, "TritonLinear does not support custom cublas mm." + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + dtype=dtype, + mapping=mapping, + tensor_parallel_mode=tensor_parallel_mode, + gather_output=gather_output, + quant_config=quant_config, + weights_loading_config=weights_loading_config, + reduce_output=reduce_output, + skip_create_weights_in_init=skip_create_weights_in_init, + use_custom_cublas_mm=use_custom_cublas_mm, + lora=lora) + + # Most of the code can be reused, only change the quant method offloading here. + def get_quant_method(self, quant_config: Optional[QuantConfig] = None): + if quant_config is None or not quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True): + return TritonUnquantizedLinearMethod() + if quant_config.layer_quant_mode.has_fp8_qdq(): + return TritonFP8QDQLinearMethod() + if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8(): + return TritonMXFP4LinearMethod(activation_dtype=torch.float8_e4m3fn) + if quant_config.layer_quant_mode.has_w4a16_mxfp4(): + assert self.dtype == torch.bfloat16, "Only bfloat16 is supported for W4A16 MXFP4" + return TritonMXFP4LinearMethod(activation_dtype=self.dtype) + raise ValueError(f'unsupported quant mode: {quant_config.quant_mode}') diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index b7204afeb4..fed6a71537 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -8,15 +8,14 @@ import torch import tensorrt_llm import tensorrt_llm.bindings.executor as trtllm from tensorrt_llm._torch.model_config import ModelConfig -from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig -from tensorrt_llm.llmapi.llm_args import PeftCacheConfig +from tensorrt_llm.llmapi.llm_args import PeftCacheConfig, SamplerType from tensorrt_llm.logger import logger -from tensorrt_llm.lora_manager import (LoraConfig, - get_default_trtllm_modules_to_hf_modules, - load_torch_lora) -from tensorrt_llm.mapping import Mapping +from tensorrt_llm.lora_helper import (LoraConfig, + get_default_trtllm_modules_to_hf_modules) +from tensorrt_llm.lora_manager import load_torch_lora +from tensorrt_llm.mapping import CpType, Mapping from ..model_config import ModelConfig from ..speculative import get_num_extra_kv_tokens, get_spec_decoder @@ -25,11 +24,11 @@ from .config_utils import is_mla, is_nemotron_hybrid from .guided_decoder import GuidedDecoder from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver from .llm_request import ExecutorResponse +from .mamba_cache_manager import MambaHybridCacheManager from .model_engine import PyTorchModelEngine from .py_executor import PyExecutor -from .resource_manager import (KVCacheManager, MambaHybridCacheManager, - PeftCacheManager, ResourceManager, - ResourceManagerType) +from .resource_manager import (KVCacheManager, PeftCacheManager, + ResourceManager, ResourceManagerType) from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, SimpleScheduler) @@ -243,8 +242,8 @@ class KvCacheCreator: torch_used_bytes = torch.cuda.memory_stats( )["allocated_bytes.all.current"] finally: - py_executor.shutdown() py_executor.is_warmup = False + py_executor.shutdown() py_executor.enable_iter_perf_stats = origin_iter_stats py_executor.set_gather_responses(False) @@ -314,6 +313,7 @@ class KvCacheCreator: dtype=kv_cache_dtype, spec_config=spec_config, max_beam_width=executor_config.max_beam_width, + is_draft=model_engine.is_draft_model, ) elif is_nemotron_hybrid(config): if executor_config.max_beam_width > 1: @@ -329,6 +329,7 @@ class KvCacheCreator: mamba_layer_mask = [ char == "M" for char in config.hybrid_override_pattern ] + kv_cache_manager = MambaHybridCacheManager( # mamba cache parameters config.ssm_state_size, @@ -339,6 +340,8 @@ class KvCacheCreator: mamba_num_layers, mamba_layer_mask, config.torch_dtype, + model_engine.model.model_config.quant_config. + mamba_ssm_cache_dtype, # kv cache parameters executor_config.kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, @@ -376,6 +379,7 @@ class KvCacheCreator: max_num_tokens=executor_config.max_num_tokens, model_config=binding_model_config, max_beam_width=executor_config.max_beam_width, + is_draft=model_engine.is_draft_model, ) # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER: @@ -502,6 +506,7 @@ def create_py_executor_instance( ) peft_cache_manager = PeftCacheManager( peft_cache_config=executor_config.peft_cache_config, + lora_config=lora_config, model_config=model_binding_config, world_config=world_config, ) @@ -523,8 +528,14 @@ def create_py_executor_instance( resource_manager.resource_managers.move_to_end( ResourceManagerType.KV_CACHE_MANAGER, last=True) + # When scheduler_capacity == 1, attention dp dummy request will prevent the scheduling of DISAGG_GENERATION_INIT. + # Enlarge scheduler capacity to avoid DISAGG_GENERATION_INIT stuck in the scheduler. + scheduler_capacity = max_num_sequences + if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager: + scheduler_capacity += 1 + capacity_scheduler = BindCapacityScheduler( - max_num_sequences, + scheduler_capacity, kv_cache_manager.impl if kv_cache_manager is not None else None, peft_cache_manager.impl if peft_cache_manager is not None else None, executor_config.scheduler_config.capacity_scheduler_policy, @@ -583,14 +594,17 @@ def instantiate_sampler(engine: PyTorchModelEngine, mapping, max_seq_len=engine.max_seq_len, enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler) - if mapping.cp_config.get('cp_type') == 'star_attention': + decoding_mode = get_decoding_mode(executor_config) + if mapping.cp_config.get('cp_type') == CpType.STAR: assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'" return TorchSampler(sampler_args) if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder( ): return get_spec_decoder(sampler_args, engine.spec_config) - if pytorch_backend_config.enable_trtllm_sampler: - decoding_mode = get_decoding_mode(executor_config) + if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or ( + pytorch_backend_config.sampler_type == SamplerType.auto + and decoding_mode.isBeamSearch()): + logger.debug(f"DecodingMode: {decoding_mode.name}") return TRTLLMSampler(executor_config, engine.model, engine.dtype, mapping, decoding_mode, pytorch_backend_config.disable_overlap_scheduler) @@ -618,92 +632,6 @@ def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode: ) decoding_mode = DecodingMode.TopKTopP() - # Override decoding mode when Medusa is used - if executor_config.speculative_config and executor_config.speculative_config.is_medusa and not decoding_mode.isMedusa( - ): - logger.warning( - "Model is Medusa, but decoding mode is not Medusa. Overwriting decoding mode to Medusa." - ) - decoding_mode = DecodingMode.Medusa() - - # Override decoding mode when Medusa is not used - if (not executor_config.speculative_config - or not executor_config.speculative_config.is_medusa - ) and decoding_mode.isMedusa(): - logger.warning( - "Model is not Medusa, but decoding mode is Medusa. Overwriting decoding mode." - ) - if executor_config.max_beam_width == 1: - decoding_mode = DecodingMode.TopKTopP() - else: - decoding_mode = DecodingMode.BeamSearch() - - # Override decoding mode when lookahead decoding is used - if executor_config.speculative_config and executor_config.speculative_config.is_lookahead and not decoding_mode.isLookahead( - ): - logger.warning( - "Model is Lookahead, but decoding mode is not Lookahead. Overwriting decoding mode to Lookahead." - ) - decoding_mode = DecodingMode.Lookahead() - - # Override decoding mode when lookahead decoding is not used - if (not executor_config.speculative_config - or not executor_config.speculative_config.is_lookahead - ) and decoding_mode.isLookahead(): - logger.warning( - "Model is not built with Lookahead decoding, but decoding mode is Lookahead. Overwriting decoding mode." - ) - if executor_config.max_beam_width == 1: - decoding_mode = DecodingMode.TopKTopP() - else: - decoding_mode = DecodingMode.BeamSearch() - - # Override decoding mode when 'explicit draft tokens' is used - if executor_config.speculative_config and executor_config.speculative_config.is_explicit_draft_tokens and not decoding_mode.isExplicitDraftTokens( - ): - logger.warning( - "Model is built with 'explicit draft tokens' decoding, but decoding mode is something else. Overwriting decoding mode." - ) - decoding_mode = DecodingMode.ExplicitDraftTokens() - - # Override decoding mode when 'explicit draft tokens' is not used - if (not executor_config.speculative_config - or not executor_config.speculative_config.is_explicit_draft_tokens - ) and decoding_mode.isExplicitDraftTokens(): - logger.warning( - "Model is not built with 'explicit draft tokens' decoding, but decoding mode is set to it. Overwriting decoding mode to default." - ) - if executor_config.max_beam_width == 1: - decoding_mode = DecodingMode.TopKTopP() - else: - decoding_mode = DecodingMode.BeamSearch() - - # Override decoding mode when EAGLE is used - if executor_config.speculative_config and executor_config.speculative_config.is_eagle and not decoding_mode.isEagle( - ): - logger.warning( - "Model is Eagle, but decoding mode is not Eagle. Overwriting decoding mode to Eagle." - ) - decoding_mode = DecodingMode.Eagle() - - # Override decoding mode when Eagle is not used - if (not executor_config.speculative_config - or not executor_config.speculative_config.is_eagle - ) and decoding_mode.isEagle(): - logger.warning( - "Model is not Eagle, but decoding mode is Eagle. Overwriting decoding mode." - ) - if executor_config.max_beam_width == 1: - decoding_mode = DecodingMode.TopKTopP() - else: - decoding_mode = DecodingMode.BeamSearch() - - # Override decoding mode when draft tokens are external - if executor_config.speculative_config and executor_config.speculative_config.is_draft_tokens_external: - logger.warning("Overwriting decoding mode to external draft token") - decoding_mode = DecodingMode.ExternalDraftTokens() - - logger.debug(f"DecodingMode: {decoding_mode.name}") return decoding_mode diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py index 0770643ae3..c656aac8c6 100644 --- a/tensorrt_llm/_torch/pyexecutor/config.py +++ b/tensorrt_llm/_torch/pyexecutor/config.py @@ -6,7 +6,7 @@ from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \ from tensorrt_llm.bindings.executor import ExecutorConfig from ...builder import BuildConfig -from ...llmapi.llm_args import LoadFormat +from ...llmapi.llm_args import LoadFormat, SamplerType from ...logger import logger from ...mapping import Mapping from ..model_config import MoeLoadBalancerConfig @@ -53,18 +53,22 @@ class PyTorchConfig: attn_backend: str = 'TRTLLM' moe_backend: str = 'CUTLASS' + moe_disable_finalize_fusion: bool = False + enable_mixed_sampler: bool = False """ If true, will iterate over sampling_params of each request and use the corresponding sampling strategy, e.g. top-k, top-p, etc. """ - enable_trtllm_sampler: bool = False + sampler_type: SamplerType = SamplerType.auto """ - If true, will use the TRTLLM sampler instead of the PyTorch sampler. - The TRTLLM sampler has a wide coverage of sampling strategies. + The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. + Defaults to auto, which will use TorchSampler unless BeamSearch is requested. """ kv_cache_dtype: str = "auto" + mamba_ssm_cache_dtype: str = "auto" + enable_iter_perf_stats: bool = False # If true, enables per request stats per iteration # Must also set enable_iter_perf_stats to true to get request stats @@ -75,6 +79,7 @@ class PyTorchConfig: torch_compile_fullgraph: bool = True torch_compile_inductor_enabled: bool = False torch_compile_piecewise_cuda_graph: bool = False + torch_compile_piecewise_cuda_graph_num_tokens: Optional[List[int]] = None # When torch compile is enabled, userbuffers is enabled by default torch_compile_enable_userbuffers: bool = True torch_compile_max_num_streams: int = 1 diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index c0f0482674..914ec6fcd9 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -5,9 +5,7 @@ def is_nemotron_hybrid(config): def is_mla(config): - if hasattr(config, "kv_lora_rank"): - assert hasattr( - config, "qk_rope_head_dim" - ), "both of kv_lora_rank and qk_rope_head_dim are required." + if getattr(config, "kv_lora_rank", None) and getattr( + config, "qk_rope_head_dim", None): return True return False diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 241fc1447c..df674a9496 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -1,28 +1,11 @@ -import threading from typing import Any, Callable, Dict, Optional, Tuple import torch from ..attention_backend.interface import AttentionMetadata +from ..modules.multi_stream_utils import with_multi_stream from ..speculative.interface import SpecMetadata -from ..utils import make_weak_ref, set_piecewise_cuda_graph_flag - - -class graph_capturing_local(threading.local): - - def __init__(self): - self.is_graph_capturing = False - - -_local = graph_capturing_local() - - -def set_graph_capturing(enable: bool): - _local.is_graph_capturing = enable - - -def is_graph_capturing() -> bool: - return _local.is_graph_capturing +from ..utils import make_weak_ref, piecewise_cuda_graph class DecodingCUDAGraphRunner: @@ -34,6 +17,7 @@ class DecodingCUDAGraphRunner: attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, use_mrope: bool = False, + max_beam_width: int = 1, ) -> None: """ Stores a CUDA graph and its associated input buffers. @@ -49,19 +33,21 @@ class DecodingCUDAGraphRunner: e.g. FlashInfer cause graph breaks). """ self.batch_size = batch_size - + self.max_beam_width = max_beam_width # [CUDA graph spec decode padding] # We pad input IDs/position IDs to the maximum draft length (token per request). # We're forced to do this because we cannot reallocate inputs over many graph runs. token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1 # Using ones instead of zeros prevents NaNs in e.g. Deepseek - self.input_ids = torch.ones((batch_size * token_per_request, ), - device=device, - dtype=torch.int32) - self.position_ids = torch.zeros((1, batch_size * token_per_request), - device=device, - dtype=torch.int32) + self.input_ids = torch.ones( + (batch_size * max_beam_width * token_per_request, ), + device=device, + dtype=torch.int32) + self.position_ids = torch.zeros( + (1, batch_size * max_beam_width * token_per_request), + device=device, + dtype=torch.int32) self.mrope_position_deltas = torch.zeros( (batch_size, 1), device=device, dtype=torch.int32) if use_mrope else None @@ -94,14 +80,11 @@ class DecodingCUDAGraphRunner: # internal states according to the docs: # https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics # This also lets us initialize states in the attn_metadata. - set_graph_capturing(True) - set_piecewise_cuda_graph_flag(False) - for _ in range(2): - forward_fn(inputs) - with torch.cuda.graph(self._graph, pool=pool): - output = forward_fn(inputs) - set_graph_capturing(False) - set_piecewise_cuda_graph_flag(True) + with with_multi_stream(True), piecewise_cuda_graph(False): + for _ in range(2): + forward_fn(inputs) + with torch.cuda.graph(self._graph, pool=pool): + output = forward_fn(inputs) # Mark weak ref here. The output tensor should be freed properly. self._output = make_weak_ref(output) return self._graph.pool() diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 5b68731fb9..17ba4983b7 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -11,6 +11,7 @@ from typing import Dict, Iterable, List, Optional, Tuple import torch from tensorrt_llm._utils import nvtx_range +from tensorrt_llm.mapping import CpType from ..distributed import Distributed from .llm_request import (ExecutorRequest, LlmRequest, @@ -233,8 +234,7 @@ class ExecutorRequestQueue: def can_enqueue_request(self) -> bool: with self.enqueue_lock: - can_enqueue = self.active - return can_enqueue and self.dist.rank == 0 + return self.active and self.dist.rank == 0 def _fetch_and_process_requests( self, @@ -570,9 +570,9 @@ class ExecutorRequestQueue: cp_config = self.dist.cp_config if 'cp_type' in cp_config: cp_type = cp_config['cp_type'] - if cp_type == 'star_attention': + if cp_type == CpType.STAR: return self._merge_star_attention_requests(new_requests) - elif cp_type == 'ring_attention': + elif cp_type == CpType.RING: raise NotImplementedError("ring attention not implemented yet") else: raise NotImplementedError(f'unsupport cp type {cp_type}') diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 7d51af7ae1..22f24752c7 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -16,11 +16,19 @@ class GrammarMatcher(ABC): def accept_token(self, token_id: int) -> bool: pass + @abstractmethod + def rollback(self, num_tokens: int) -> None: + pass + @abstractmethod def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, index: int) -> None: pass + @abstractmethod + def is_terminated(self) -> bool: + pass + class GrammarMatcherFactory(ABC): @@ -39,15 +47,23 @@ class XGrammarMatcher(GrammarMatcher): def accept_token(self, token_id: int) -> bool: return self._matcher.accept_token(token_id) + def rollback(self, num_tokens: int) -> None: + self._matcher.rollback(num_tokens) + def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, index: int) -> None: self._matcher.fill_next_token_bitmask(next_token_bitmask, index) + def is_terminated(self) -> bool: + return self._matcher.is_terminated() + class XGrammarMatcherFactory(GrammarMatcherFactory): - def __init__(self, guided_decoding_config: GuidedDecodingConfig, - vocab_size_padded: int): + def __init__(self, + guided_decoding_config: GuidedDecodingConfig, + vocab_size_padded: int, + max_num_draft_tokens: int = 0): super().__init__() vocab_type = xgrammar.VocabType.RAW add_prefix_space = False @@ -72,6 +88,7 @@ class XGrammarMatcherFactory(GrammarMatcherFactory): cache_enabled=True, cache_limit_bytes=cache_limit_bytes, ) + self.max_num_draft_tokens = max_num_draft_tokens def create(self, guided_decoding_params: GuidedDecodingParams) -> XGrammarMatcher: @@ -106,20 +123,38 @@ class XGrammarMatcherFactory(GrammarMatcherFactory): case _: raise ValueError(f"Unsupported guide type: {guide_type}.") - matcher = xgrammar.GrammarMatcher(compiled_grammar) + matcher = xgrammar.GrammarMatcher( + compiled_grammar, max_rollback_tokens=self.max_num_draft_tokens) return XGrammarMatcher(matcher) class LLGuidanceMatcher(GrammarMatcher): - def __init__(self, matcher: llguidance.LLMatcher): + def __init__(self, matcher: llguidance.LLMatcher, eos_token: int): super().__init__() self._matcher = matcher + self._eos_token = eos_token + self._is_terminated = False def accept_token(self, token_id: int) -> bool: - result = self._matcher.consume_token(token_id) + if self._matcher.is_stopped(): + # Accept EOS token only if the matcher is stopped. + if token_id == self._eos_token: + self._is_terminated = True + return True + else: + return False + + num_accepted = self._matcher.try_consume_tokens([token_id]) + self._check_err() + return num_accepted > 0 + + def rollback(self, num_tokens: int) -> None: + if self._is_terminated: + self._is_terminated = False + num_tokens -= 1 + self._matcher.rollback(num_tokens) self._check_err() - return result def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, index: int) -> None: @@ -127,6 +162,9 @@ class LLGuidanceMatcher(GrammarMatcher): next_token_bitmask, index) self._check_err() + def is_terminated(self) -> bool: + return self._is_terminated + def _check_err(self) -> None: if self._matcher.is_error(): raise ValueError( @@ -181,4 +219,4 @@ class LLGuidanceMatcherFactory(GrammarMatcherFactory): if matcher.is_error(): raise ValueError(f"LLGuidance matcher error: {matcher.get_error()}") - return LLGuidanceMatcher(matcher) + return LLGuidanceMatcher(matcher, self._tokenizer.eos_token) diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py index f1b21339b9..cc262699d8 100644 --- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py +++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py @@ -5,19 +5,25 @@ import torch from ..._utils import nvtx_range from ...bindings.executor import GuidedDecodingConfig +from ...logger import logger from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory, LLGuidanceMatcherFactory, XGrammarMatcherFactory) +from .llm_request import LlmRequest from .scheduler import ScheduledRequests class GuidedDecoder: bitmask_dtype = torch.int32 - def __init__(self, guided_decoding_config: GuidedDecodingConfig, - max_num_sequences: int, vocab_size_padded: int): + def __init__(self, + guided_decoding_config: GuidedDecodingConfig, + max_num_sequences: int, + vocab_size_padded: int, + max_num_draft_tokens: int = 0): self.guided_decoding_backend = guided_decoding_config.backend self.max_num_sequences = max_num_sequences self.vocab_size_padded = vocab_size_padded + self.max_num_draft_tokens = max_num_draft_tokens self.grammar_matcher_factory: Optional[GrammarMatcherFactory] = None self.grammar_matchers: List[ @@ -25,71 +31,239 @@ class GuidedDecoder: if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR: self.grammar_matcher_factory = XGrammarMatcherFactory( - guided_decoding_config, vocab_size_padded) + guided_decoding_config, + vocab_size_padded, + max_num_draft_tokens=max_num_draft_tokens) elif self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE: self.grammar_matcher_factory = LLGuidanceMatcherFactory( guided_decoding_config, vocab_size_padded) else: raise ValueError( - f"invalid guided_decoding_backend: {self.guided_decoding_backend}" + f"Invalid guided decoding backend: {self.guided_decoding_backend}" ) + logger.info( + f"Guided decoder initialized with backend: {self.guided_decoding_backend}" + ) self.bitmask = torch.empty(self.max_num_sequences, + self.max_num_draft_tokens + 1, self.bitmask_size, dtype=self.bitmask_dtype, device='cuda') self.bitmask_host = torch.empty(self.max_num_sequences, + self.max_num_draft_tokens + 1, self.bitmask_size, dtype=self.bitmask_dtype, pin_memory=True) + # The number of tokens accepted by the grammar matcher in a build step. + self.num_advanced_tokens: List[int] = [0] * self.max_num_sequences + # The number of tokens with filled bitmask in a build step. + self.num_guided_tokens: List[int] = [0] * self.max_num_sequences + # The accumulated number of tokens accepted by the grammar matcher in a drafting loop. + self.num_advanced_draft_tokens: List[int] = [0] * self.max_num_sequences + # Whether is guided drafting is terminated because of unacceptable drafted tokens. + self.is_draft_terminated: List[bool] = [False] * self.max_num_sequences + self._stream = torch.cuda.Stream() @property def bitmask_size(self) -> int: return math.ceil(self.vocab_size_padded / 32) + def _require_matcher_init(self, llm_req: LlmRequest) -> bool: + if llm_req.guided_decoding_params is None: + return False + if llm_req.py_is_draft: + return False + # The request is in the last chunk of a context forward step. + return llm_req.is_context_init_state and llm_req.is_last_context_chunk + + def _require_matcher_advance(self, llm_req: LlmRequest) -> bool: + if llm_req.guided_decoding_params is None: + return False + if llm_req.py_is_draft: + if llm_req.is_context_init_state and llm_req.is_last_context_chunk: + return True + if llm_req.is_generation_in_progress_state: + return True + return False + # The request is in a generation forward step. + return llm_req.is_generation_in_progress_state + + @torch.inference_mode() @nvtx_range("GuidedDecoder.build") def build(self, scheduled_requests: ScheduledRequests) -> None: + """Build the bitmask for requests with guided decoding enabled. + + Specifically, this method: + - build and advance the grammar matcher for context and generation requests, respectively; + - call the grammar matcher to fill the bitmask on CPU; + - asynchronously copy the bitmask to GPU. + """ for llm_req in scheduled_requests.all_requests(): - if llm_req.guided_decoding_params is None: - continue - slot = llm_req.py_seq_slot - if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len: - self.grammar_matchers[ - slot] = self.grammar_matcher_factory.create( - llm_req.guided_decoding_params) + slot: int = llm_req.py_target_seq_slot if llm_req.py_is_draft else llm_req.py_seq_slot + self.num_advanced_tokens[slot] = 0 + self.num_guided_tokens[slot] = 0 - elif llm_req.is_generation_in_progress_state: - # The request is in a generation forward step. - # Currently, guided decoding does not support with beam search. - self.grammar_matchers[slot].accept_token( - llm_req.get_last_tokens(0)) - else: + matcher_init: bool = self._require_matcher_init(llm_req) + matcher_advance: bool = self._require_matcher_advance(llm_req) + if not (matcher_init or matcher_advance): continue - # Fill the bitmask on host and asynchorously copy to device. - self.grammar_matchers[slot].fill_next_token_bitmask( - self.bitmask_host, slot) - with torch.cuda.stream(self._stream): - self.bitmask[slot].copy_(self.bitmask_host[slot], - non_blocking=True) + if matcher_init: + matcher = self.grammar_matcher_factory.create( + llm_req.guided_decoding_params) + self.grammar_matchers[slot] = matcher + if matcher_advance: + matcher = self.grammar_matchers[slot] + # The last new token must be acceptable unless the matcher is terminated in a drafting loop. + if llm_req.py_is_draft and (matcher.is_terminated() + or self.is_draft_terminated[slot]): + continue + last_new_token = llm_req.get_last_tokens(0) + accepted = matcher.accept_token(last_new_token) + if not accepted: + if llm_req.py_is_draft: + self.is_draft_terminated[slot] = True + logger.debug( + f"Draft request {llm_req.py_request_id} failed to accept last new token: {last_new_token}." + ) + continue + # TODO: Make this an error response. + raise ValueError( + f"Request {llm_req.py_request_id} failed to accept last new token: {last_new_token}." + ) + + self.num_advanced_tokens[slot] += 1 + if not matcher.is_terminated(): + matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0) + self.num_guided_tokens[slot] += 1 + # Process draft tokens + for i, tid in enumerate(llm_req.py_draft_tokens, 1): + accepted = matcher.accept_token(tid) + if not accepted: + break + self.num_advanced_tokens[slot] += 1 + if matcher.is_terminated(): + break + matcher.fill_next_token_bitmask(self.bitmask_host[slot], i) + self.num_guided_tokens[slot] += 1 + + if llm_req.py_is_draft: + assert len(llm_req.py_draft_tokens) == 0 + self.num_advanced_draft_tokens[ + slot] += self.num_advanced_tokens[slot] + + if (num_guided_tokens := self.num_guided_tokens[slot]) > 0: + with torch.cuda.stream(self._stream): + self.bitmask[slot, :num_guided_tokens].copy_( + self.bitmask_host[slot, :num_guided_tokens], + non_blocking=True) + + @torch.inference_mode() @nvtx_range("GuidedDecoder.execute") - def execute(self, scheduled_requests: ScheduledRequests, - logits: torch.Tensor) -> None: - assert logits.size(0) == len(scheduled_requests.context_requests) + len( - scheduled_requests.generation_requests) + def execute(self, + scheduled_requests: ScheduledRequests, + logits: torch.Tensor, + d2t: Optional[torch.Tensor] = None) -> None: + """Apply the bitmask to the corresponding logits for requests with guided decoding enabled. + + This method inplace modifies the logits tensor so that any tokens that violate the grammar constraints are masked out. + """ torch.cuda.current_stream().wait_stream(self._stream) + # TODO: Fuse index_copy and index_select to logits_bitmask. + if d2t is not None: + draft_logits = logits + d2t_mapping = d2t + torch.arange(d2t.size(0), device=d2t.device) + logits = torch.empty(draft_logits.size(0), + self.vocab_size_padded, + dtype=draft_logits.dtype, + device=draft_logits.device) + logits.index_copy_(-1, d2t_mapping, draft_logits) + batched_logits, batched_bitmask = [], [] - for i, llm_req in enumerate(scheduled_requests.all_requests()): - if llm_req.guided_decoding_params is None: - continue - if llm_req.is_context_init_state and not llm_req.is_last_context_chunk: - continue - batched_logits.append(logits[i]) - batched_bitmask.append(self.bitmask[llm_req.py_seq_slot]) + offset = 0 + for llm_req in scheduled_requests.all_requests(): + slot: int = llm_req.py_target_seq_slot if llm_req.py_is_draft else llm_req.py_seq_slot + for i in range(self.num_guided_tokens[slot]): + batched_logits.append(logits[offset + i]) + batched_bitmask.append(self.bitmask[slot, i]) + offset += len(llm_req.py_draft_tokens) + 1 + + # Dummy logits may exist for CUDA graph dummy requests. + assert offset <= logits.size(0) if len(batched_logits) > 0: torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask) + + if d2t is not None: + torch.index_select(logits, -1, d2t_mapping, out=draft_logits) + + @nvtx_range("GuidedDecoder.rollback_rejected_tokens") + def rollback_rejected_tokens(self, + scheduled_requests: ScheduledRequests) -> None: + """Rollback the grammar matcher for rejected tokens. + + This method should be called: + - after the verification (so that the accepted tokens are ready) and + - before the first guided decoding build of the next drafting loop. + """ + if self.max_num_draft_tokens <= 0: + return + + for llm_req in scheduled_requests.all_requests(): + assert not llm_req.py_is_draft + slot: int = llm_req.py_seq_slot + if self.num_advanced_tokens[slot] <= 0: + continue + # Rollback the grammar matcher to the last accepted token. + num_rollback_tokens = self.num_advanced_tokens[slot] - ( + 1 + llm_req.py_num_accepted_draft_tokens) + # TODO: Make this an error response. + if num_rollback_tokens < 0: + raise ValueError( + f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_draft_tokens={llm_req.py_num_accepted_draft_tokens}, num_rollback_tokens={num_rollback_tokens}" + ) + self.grammar_matchers[slot].rollback(num_rollback_tokens) + + @nvtx_range("GuidedDecoder.rollback_draft_tokens") + def rollback_draft_tokens(self, + scheduled_requests: ScheduledRequests) -> None: + """Rollback the grammar matcher for draft tokens. + + This method should be called: + - after the the drafting loop and + - before the guided decoding build of the target model. + """ + if self.max_num_draft_tokens <= 0: + return + + for llm_req in scheduled_requests.all_requests(): + assert not llm_req.py_is_draft + slot: int = llm_req.py_seq_slot + if self.num_advanced_draft_tokens[slot] <= 0: + continue + self.grammar_matchers[slot].rollback( + self.num_advanced_draft_tokens[slot]) + # Reset the drafting states. + self.num_advanced_draft_tokens[slot] = 0 + self.is_draft_terminated[slot] = False + + @nvtx_range("GuidedDecoder.init_disagg_gen_requests") + def init_disagg_gen_requests(self, + scheduled_requests: ScheduledRequests) -> None: + """Initialize the grammar matchers for disagg gen requests. + """ + for llm_req in scheduled_requests.generation_requests: + if llm_req.guided_decoding_params is None: + continue + assert not llm_req.py_is_draft + slot: int = llm_req.py_seq_slot + if llm_req.context_phase_params is not None and llm_req.py_decoding_iter == 1: + # The request is in the first generation forward step at the disagg gen instance. + self.grammar_matchers[ + slot] = self.grammar_matcher_factory.create( + llm_req.guided_decoding_params) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 0fb1f06e96..80f1153e50 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -250,6 +250,12 @@ class LlmResult: self._result = tensorrt_llm.bindings.executor.deserialize_result( self._result) + def get_result(self): + if tmp_res := tensorrt_llm.bindings.executor.deserialize_result( + self._result): + return tmp_res + return None + @dataclass class LlmResponse: @@ -281,10 +287,13 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): llm_request: Optional[ tensorrt_llm.bindings.internal.batch_manager.LlmRequest] = None, is_draft: bool = False, + seq_slot: Optional[int] = None, + target_seq_slot: Optional[int] = None, **kwargs): self.py_logits_post_processors = kwargs.pop("py_logits_post_processors", None) + self.py_lora_path: str | None = kwargs.pop("py_lora_path", None) # Multimodal data self.py_multimodal_data = kwargs.pop("py_multimodal_data", None) if llm_request is not None: @@ -308,9 +317,12 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): self.py_orig_prompt_len = self.orig_prompt_len self.py_max_new_tokens = self.max_new_tokens self.py_batch_idx = None + self.py_draft_pages_allocated = 0 self.py_rewind_len = 0 self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens self.py_last_context_chunk = (None, None) + self.py_draft_logits = None + self.py_target_probs = None self.py_last_draft_tokens = None self.py_num_accepted_draft_tokens = 0 self.py_decoding_iter = 0 @@ -323,7 +335,11 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): self.py_return_generation_logits = return_generation_logits self.py_return_logits_device_memory = return_logits_device_memory self.py_is_draft = is_draft - self.py_seq_slot = None + # The request's sequence slot ID, an index between 0 (inclusive) and max_batch_size (exclusive). + self.py_seq_slot = seq_slot + # If the request is a draft request, target_seq_slot is the sequence slot ID of its target request. + self.py_target_seq_slot = target_seq_slot + self.use_draft_model = is_draft # TODO: remove this when use DynamicDecodeOp in pytorch flow. # currently, keep py_stop_words_list as python list, rather than tensor. @@ -490,6 +506,7 @@ def executor_request_to_llm_request( if executor_request.lora_config is not None else None, lora_config=executor_request.lora_config.config if executor_request.lora_config is not None else None, + py_lora_path=getattr(executor_request, "py_lora_path", None), mrope_rotary_cos_sin=mrope_rotary_cos_sin, mrope_position_deltas=mrope_position_deltas, lookahead_config=None, diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py new file mode 100644 index 0000000000..707fdf33fb --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Union + +import torch + +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest +from tensorrt_llm._torch.pyexecutor.resource_manager import ( + BaseResourceManager, CacheTypeCpp, DataType, KvCacheConfigCpp, + KVCacheManager, get_pp_layers) +from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests +from tensorrt_llm.mapping import Mapping + + +class MambaCacheManager(BaseResourceManager): + + def __init__( + self, + d_state: int, + d_conv: int, + num_heads: int, + n_groups: int, + head_dim: int, + num_layers: int, + max_batch_size: int, + mapping: Mapping, + dtype: torch.dtype, + ssm_cache_dtype: torch.dtype, + layer_mask: Optional[List[bool]] = None, + ) -> None: + + self.mamba_ssm_cache_dtype = ssm_cache_dtype + + # get tp size + tp_size = mapping.tp_size + + # derive mamba parameters for conv and ssm states + d_inner = head_dim * num_heads + conv_dim = d_inner + 2 * n_groups * d_state + nheads = num_heads + + # check that can be partitioned + assert nheads % tp_size == 0, "nheads must be divisible by tp_size" + assert conv_dim % tp_size == 0, "conv_dim must be divisible by tp_size" + + # partition conv_dim and nheads + conv_dim = conv_dim // tp_size + nheads = nheads // tp_size + + # conv and ssm states device + device = torch.device("cuda") + + pp_layers, num_layers = get_pp_layers( + num_layers, + mapping, + layer_mask=layer_mask, + ) + num_local_layers = len(pp_layers) + self.mamba_layer_offsets = { + idx: offset + for offset, idx in enumerate(pp_layers) + } + + # mamba conv states + self.conv_states = torch.empty( + size=[ + num_local_layers, + max_batch_size, + conv_dim, + d_conv - 1, + ], + dtype=dtype, + device=device, + ) + + # mamba ssm states + self.ssm_states = torch.empty( + size=[ + num_local_layers, + max_batch_size, + nheads, + head_dim, + d_state, + ], + dtype=self.mamba_ssm_cache_dtype, + device=device, + ) + + # mamba cache available blocks + self.mamba_cache_free_blocks = [i for i in range(max_batch_size)] + + # mamba cache index, maps request_id -> state indices + self.mamba_cache_index: Dict[int, int] = {} + + # mamba cache state indices + self.state_indices: torch.Tensor = torch.arange(max_batch_size, + device=device, + dtype=torch.int32) + + def _prepare_mamba_cache_blocks(self, request_ids: List[int]): + state_indices = [] + for r in request_ids: + # cache hit + if r in self.mamba_cache_index: + state_indices.append(self.mamba_cache_index[r]) + # cache miss + else: + if len(self.mamba_cache_free_blocks) == 0: + raise Exception("run out of mamba cache blocks") + block = self.mamba_cache_free_blocks.pop() + self.mamba_cache_index[r] = block + state_indices.append(block) + self.state_indices[:len(state_indices)] = torch.as_tensor( + state_indices, dtype=torch.int32, device=self.ssm_states.device) + + def prepare_resources(self, scheduled_batch: ScheduledRequests): + context_ids = [ + i.py_request_id for i in scheduled_batch.context_requests + ] + generation_ids = [ + i.py_request_id for i in scheduled_batch.generation_requests + ] + request_ids = context_ids + generation_ids + self._prepare_mamba_cache_blocks(request_ids) + + def free_resources(self, request: LlmRequest): + request_id = request.py_request_id + if request_id in self.mamba_cache_index: + block = self.mamba_cache_index.pop(request_id) + self.mamba_cache_free_blocks.append(block) + + def get_state_indices(self) -> torch.Tensor: + return self.state_indices + + def get_conv_states(self, layer_idx: int) -> torch.Tensor: + layer_offset = self.mamba_layer_offsets[layer_idx] + return self.conv_states[layer_offset] + + def get_ssm_states(self, layer_idx: int) -> torch.Tensor: + layer_offset = self.mamba_layer_offsets[layer_idx] + return self.ssm_states[layer_offset] + + def get_mamba_ssm_cache_dtype(self) -> torch.dtype: + return self.mamba_ssm_cache_dtype + + def shutdown(self): + # release tensor memory, keeping python references as tensors + self.conv_states = torch.tensor([]) + self.ssm_states = torch.tensor([]) + self.state_indices = torch.tensor([]) + torch.cuda.empty_cache() + + +class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): + + def __init__( + self, + # mamba cache parameters + mamba_d_state: int, + mamba_d_conv: int, + mamba_num_heads: int, + mamba_n_groups: int, + mamba_head_dim: int, + mamba_num_layers: int, + mamba_layer_mask: List[bool], + mamba_cache_dtype: torch.dtype, + mamba_ssm_cache_dtype: torch.dtype, + + # kv cache parameters + kv_cache_config: KvCacheConfigCpp, + kv_cache_type: CacheTypeCpp, + *, + num_layers: int, + layer_mask: List[bool], + num_kv_heads: Union[int, List[Optional[int]]], + head_dim: int, + tokens_per_block: int, + # Note that max_seq_len is not necessarily equal to kv_cache_config.num_tokens. + # It's derived from the model's BuildConfig for consistency with the C++ backend. + max_seq_len: int, + max_batch_size: int, + mapping: Mapping, + dtype: DataType = DataType.HALF, + spec_config: Optional["DecodingBaseConfig"] = None, + ) -> None: + + # mamba hybrid cache requires block reuse to be disabled in KV cache config + assert not kv_cache_config.enable_block_reuse, "mamba hybrid cache requires block reuse to be disabled in KV cache config" + + # initialize mamba cache manager + MambaCacheManager.__init__( + self, + mamba_d_state, + mamba_d_conv, + mamba_num_heads, + mamba_n_groups, + mamba_head_dim, + mamba_num_layers, + max_batch_size, + mapping, + mamba_cache_dtype, + mamba_ssm_cache_dtype, + mamba_layer_mask, + ) + + # initialize kv cache manager + KVCacheManager.__init__( + self, + kv_cache_config, + kv_cache_type, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=dtype, + spec_config=spec_config, + layer_mask=layer_mask, + ) + + def prepare_resources(self, scheduled_batch: ScheduledRequests): + MambaCacheManager.prepare_resources(self, scheduled_batch) + KVCacheManager.prepare_resources(self, scheduled_batch) + + def free_resources(self, request: LlmRequest): + MambaCacheManager.free_resources(self, request) + KVCacheManager.free_resources(self, request) + + def shutdown(self): + MambaCacheManager.shutdown(self) + KVCacheManager.shutdown(self) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 2d00cee05f..a34f03edb5 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -22,12 +22,14 @@ from tensorrt_llm._torch.speculative import ( get_num_extra_kv_tokens, update_spec_config_from_model_config) from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, - torch_dtype_to_str, trace_func) + str_dtype_to_torch, torch_dtype_to_str, + trace_func) from tensorrt_llm.inputs.multimodal import (MultimodalParams, MultimodalRuntimeData) from tensorrt_llm.logger import logger -from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig -from tensorrt_llm.mapping import Mapping +from tensorrt_llm.lora_helper import LoraConfig +from tensorrt_llm.lora_manager import LoraModelConfig +from tensorrt_llm.mapping import CpType, Mapping from tensorrt_llm.models.modeling_utils import QuantAlgo from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 @@ -38,7 +40,7 @@ from ..attention_backend.utils import get_attention_backend from ..attention_backend.vanilla import VanillaAttentionMetadata from ..autotuner import AutoTuner, autotune from ..compilation.backend import Backend -from ..compilation.utils import set_enable_piecewise_cuda_graph_capture_flag +from ..compilation.utils import capture_piecewise_cuda_graph from ..distributed import MPIDist from ..distributed.communicator import init_pp_comm from ..expert_statistic import ExpertStatistic @@ -98,6 +100,16 @@ _KV_CACHE_MAP = { _VALID_KV_CACHE_DTYPES = ("fp8", "auto") +def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, + mamba_ssm_cache_dtype: str) -> None: + if mamba_ssm_cache_dtype == "auto": + mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype + else: + mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype) + + config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype + + def validate_and_set_kv_cache_quant(model_config: ModelConfig, pyt_kv_cache_dtype: str) -> QuantAlgo: logger.info( @@ -281,8 +293,6 @@ class PyTorchModelEngine(ModelEngine): self.enable_spec_decode = self.is_spec_decode self.is_draft_model = is_draft_model - self.in_warmup = False - self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures( ) @@ -293,6 +303,8 @@ class PyTorchModelEngine(ModelEngine): checkpoint_loader=checkpoint_loader, attn_backend=attn_backend, moe_backend=pytorch_backend_config.moe_backend, + moe_disable_finalize_fusion=pytorch_backend_config. + moe_disable_finalize_fusion, load_format=pytorch_backend_config.load_format, max_num_tokens=max_num_tokens, moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens, @@ -321,18 +333,30 @@ class PyTorchModelEngine(ModelEngine): pytorch_backend_config.torch_compile_piecewise_cuda_graph and not self.enable_attention_dp) + piecewise_cuda_graph_num_tokens = ( + pytorch_backend_config.torch_compile_piecewise_cuda_graph_num_tokens + or pytorch_backend_config.cuda_graph_batch_sizes or []) + + self._piecewise_cuda_graph_num_tokens = [ + i for i in piecewise_cuda_graph_num_tokens + if i <= self.max_num_tokens + ] + try: + use_ub_for_nccl = ( + pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC" + and self._init_userbuffers(self.model.config.hidden_size)) if pytorch_backend_config.torch_compile_enabled: set_torch_compiling(True) - use_ub = pytorch_backend_config.torch_compile_enable_userbuffers and self._init_userbuffers( - self.model.config.hidden_size) + use_ub = not use_ub_for_nccl and ( + pytorch_backend_config.torch_compile_enable_userbuffers + and self._init_userbuffers(self.model.config.hidden_size)) self._torch_compile_backend = Backend( pytorch_backend_config.torch_compile_inductor_enabled, enable_userbuffers=use_ub, enable_piecewise_cuda_graph=self. _torch_compile_piecewise_cuda_graph, - cuda_graph_batch_sizes=pytorch_backend_config. - cuda_graph_batch_sizes, + capture_num_tokens=self._piecewise_cuda_graph_num_tokens, max_num_streams=pytorch_backend_config. torch_compile_max_num_streams) if isinstance(self.model, DecoderModelForCausalLM): @@ -355,6 +379,8 @@ class PyTorchModelEngine(ModelEngine): traceback.print_exception(Exception, e, e.__traceback__) raise e + self.is_warmup = False + self.attn_backend = get_attention_backend(attn_backend) if self.is_spec_decode: @@ -437,6 +463,10 @@ class PyTorchModelEngine(ModelEngine): else: self.cache_indirection_attention = None + @property + def runtime_draft_len(self): + return self.max_draft_len if self.enable_spec_decode else 0 + def set_lora_model_config(self, lora_target_modules: list[str], trtllm_modules_to_hf_modules: dict[str, str]): self.lora_model_config = LoraModelConfig( @@ -453,20 +483,47 @@ class PyTorchModelEngine(ModelEngine): 'type'] == 'mrope' except Exception: pass - logger.info(f"Detected use_mrope: {use_mrope}") + logger.debug(f"Detected use_mrope: {use_mrope}") return use_mrope + @property + def is_warmup(self): + return getattr(self, "_is_warmup", False) + + @is_warmup.setter + def is_warmup(self, value: bool): + self._is_warmup = value + + self.moe_load_balancer_iter_info = (not value, not value) + + @property + def moe_load_balancer_iter_info(self): + moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer', + None) + if moe_load_balancer is not None: + return moe_load_balancer.enable_statistic, moe_load_balancer.enable_update_weights + return False, False + + @moe_load_balancer_iter_info.setter + def moe_load_balancer_iter_info(self, value: Tuple[bool, bool]): + moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer', + None) + if moe_load_balancer is not None: + moe_load_balancer.set_iter_info(enable_statistic=value[0], + enable_update_weights=value[1]) + @property def use_beam_search(self): return self.max_beam_width > 1 @contextmanager def set_warmup_flag(self): - self.in_warmup = True + prev_is_warmup = self.is_warmup + self.is_warmup = True try: yield finally: - self.in_warmup = False + self.is_warmup = prev_is_warmup @staticmethod def with_warmup_flag(method): @@ -557,7 +614,7 @@ class PyTorchModelEngine(ModelEngine): list(range(batch_size)), [num_tokens_per_request] * batch_size if not is_gen else None, is_gen=is_gen, - max_num_draft_tokens=self.max_draft_len) + max_num_draft_tokens=self.runtime_draft_len) if spec_resource_manager is not None: spec_resource_manager.add_dummy_requests( @@ -576,7 +633,7 @@ class PyTorchModelEngine(ModelEngine): def get_autotune_warmup_request(): available_tokens = kv_cache_manager.get_num_available_tokens( - self.max_draft_len) + self.runtime_draft_len) num_tokens_per_request = min( min(available_tokens, self.max_seq_len - 1), self.max_num_tokens) @@ -610,14 +667,14 @@ class PyTorchModelEngine(ModelEngine): request_ids=list(range(full_len_request_num)), token_nums=[num_tokens_per_request] * full_len_request_num, is_gen=False, - max_num_draft_tokens=self.max_draft_len) + max_num_draft_tokens=self.runtime_draft_len) if remaining_tokens > 0: final_request = kv_cache_manager.add_dummy_requests( request_ids=[full_len_request_num], token_nums=[remaining_tokens], is_gen=False, - max_num_draft_tokens=self.max_draft_len) + max_num_draft_tokens=self.runtime_draft_len) requests += final_request @@ -644,120 +701,113 @@ class PyTorchModelEngine(ModelEngine): # TODO: current warmup_request is not suitable for star attention cp_type = self.mapping.cp_config.get('cp_type', None) - if cp_type == 'star_attention': + if cp_type == CpType.STAR: return - with contextlib.ExitStack() as stack: - if self._torch_compile_enabled: + if self._torch_compile_enabled: - def disable_optimization(backend: Backend): - # Disable torch.compile optimization and fallback to eager execution - backend.bypass_optimization() - # Disable piecewise CUDA graph capture since the capture run will produce wrong results - set_enable_piecewise_cuda_graph_capture_flag(False) - - stack.callback(disable_optimization, - self._torch_compile_backend) - - self._torch_compile_backend.enable_optimization() - - # Disable cuda graph capture here so that we can properly capture it later - with self.no_cuda_graph(): - available_tokens = kv_cache_manager.get_num_available_tokens( - self.max_draft_len) - warmup_batch_size = [1, self.batch_size // 2] - if self.batch_size < 2: - warmup_batch_size = [1] - for bs in warmup_batch_size: - for num_tokens_per_request in [ - 1, - min(self.max_num_tokens // max(bs, 1), - min(available_tokens, self.max_seq_len - 1)) - ]: - with release_batch( - get_torch_compile_warmup_request( - bs, num_tokens_per_request)) as batch: - if batch is None: - # No KV cache space! - continue - logger.info( - f"Run warmup for batch size={bs}, pure {'context' if num_tokens_per_request > 1 else 'generation'} phase" - ) - self.forward(batch, - new_tensors_device=None, - resource_manager=resource_manager) - torch.cuda.synchronize() - - if self.pytorch_backend_config.enable_autotuner: - with self.no_cuda_graph(), autotune(): - result = get_autotune_warmup_request() - with release_batch(result) as batch: - if batch is None: - # No KV cache space! - pass - else: + # Disable cuda graph capture here so that we can properly capture it later + with self.no_cuda_graph(): + available_tokens = kv_cache_manager.get_num_available_tokens( + self.runtime_draft_len) + warmup_batch_size = [1, self.batch_size // 2] + if self.batch_size < 2: + warmup_batch_size = [1] + for bs in warmup_batch_size: + for num_tokens_per_request in [ + 1, + min(self.max_num_tokens // max(bs, 1), + min(available_tokens, self.max_seq_len - 1)) + ]: + with release_batch( + get_torch_compile_warmup_request( + bs, num_tokens_per_request)) as batch: + if batch is None: + # No KV cache space! + continue + logger.info( + f"Run warmup for batch size={bs}, pure {'context' if num_tokens_per_request > 1 else 'generation'} phase" + ) self.forward(batch, new_tensors_device=None, resource_manager=resource_manager) torch.cuda.synchronize() - logger.info( - f"[Autotuner] Cache size after warmup is {len(AutoTuner.get().profiling_cache)}" - ) - - AutoTuner.get().print_profiling_cache() - - if not (self._run_cuda_graphs - or self._torch_compile_piecewise_cuda_graph): - return - - logger.info( - f"Creating CUDA graph instances for {len(self._cuda_graph_batch_sizes)} batch sizes." - ) - # Reverse the order of the cuda graph batch sizes to make smaller batch size graph could reuse larger batch size graph memory - cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes, - reverse=True) - # Create CUDA graphs for different draft lengths - draft_lengths = [self.max_draft_len] - # For non-draft model, we also capture the CUDA graph instance for draft length 0, - # so that when we disable spec decode at runtime, we can still run the captured graph. - # Note that for one engine mode, we are not able to turn off spec decode at runtime. - if not self.is_draft_model and self.max_draft_len > 0 and not self.spec_config.spec_dec_mode.use_one_engine( - ): - draft_lengths.append(0) - - for bs in cuda_graph_batch_sizes: - if bs > self.batch_size: - # skip batch size larger than self.batch_size - continue - - for draft_len in draft_lengths: - with release_batch( - get_cuda_graph_warmup_request(bs, - draft_len)) as batch: - if batch is None: - # No KV cache space! - return - logger.info( - f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}" - ) - self.enable_spec_decode = draft_len > 0 or self.is_draft_model + if self.pytorch_backend_config.enable_autotuner: + with self.no_cuda_graph(), autotune(): + result = get_autotune_warmup_request() + with release_batch(result) as batch: + if batch is None: + # No KV cache space! + pass + else: self.forward(batch, new_tensors_device=None, resource_manager=resource_manager) torch.cuda.synchronize() - if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled: - for seq_lens in cuda_graph_batch_sizes: - set_enable_piecewise_cuda_graph_capture_flag(True) + logger.info( + f"[Autotuner] Cache size after warmup is {len(AutoTuner.get().profiling_cache)}" + ) + + AutoTuner.get().print_profiling_cache() + + if not (self._run_cuda_graphs + or self._torch_compile_piecewise_cuda_graph): + return + + logger.info( + f"Creating CUDA graph instances for {len(self._cuda_graph_batch_sizes)} batch sizes." + ) + # Reverse the order of the cuda graph batch sizes to make smaller batch size graph could reuse larger batch size graph memory + cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes, + reverse=True) + # Create CUDA graphs for different draft lengths + draft_lengths = [self.max_draft_len] + # For non-draft model, we also capture the CUDA graph instance for draft length 0, + # so that when we disable spec decode at runtime, we can still run the captured graph. + # Note that for one engine mode, we are not able to turn off spec decode at runtime. + if (not self.is_draft_model and self.max_draft_len > 0 + and not self.spec_config.spec_dec_mode.use_one_engine() + # Assume that speculation is always on if the user didn't give us a max_concurrency + # value. This will save on memory. + and self.spec_config.max_concurrency is not None): + draft_lengths.append(0) + + for bs in cuda_graph_batch_sizes: + if bs > self.batch_size: + # skip batch size larger than self.batch_size + continue + + for draft_len in draft_lengths: + with release_batch(get_cuda_graph_warmup_request( + bs, draft_len)) as batch: + if batch is None: + # No KV cache space! + return + logger.info( + f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}" + ) + self.enable_spec_decode = draft_len > 0 or self.is_draft_model + self.forward(batch, + new_tensors_device=None, + resource_manager=resource_manager) + torch.cuda.synchronize() + + if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled: + piecewise_cuda_graph_num_tokens = sorted( + self._piecewise_cuda_graph_num_tokens, reverse=True) + + with capture_piecewise_cuda_graph(True): + for num_tokens in piecewise_cuda_graph_num_tokens: with self.no_cuda_graph(): with release_batch( get_torch_compile_warmup_request( - 1, seq_lens)) as batch: + 1, num_tokens)) as batch: logger.info( - f"Run piecewise CUDA graph warmup for seq_lens={seq_lens}" + f"Run piecewise CUDA graph warmup for num tokens={num_tokens}" ) - # self.model.mtp_worker.stored_input_ids = [] + for _ in range(3): self.forward(batch, new_tensors_device=None, @@ -768,7 +818,6 @@ class PyTorchModelEngine(ModelEngine): torch.cuda.synchronize() gc.collect() torch.cuda.empty_cache() - set_enable_piecewise_cuda_graph_capture_flag(False) # Set the value back to the original value self.enable_spec_decode = self.is_spec_decode @@ -842,8 +891,8 @@ class PyTorchModelEngine(ModelEngine): spec_resource_manager: Optional[BaseResourceManager] = None) -> int: can_run_cuda_graph = scheduled_requests.can_run_cuda_graph batch_size = scheduled_requests.batch_size - # The number of sequences in the batch is the number of prompts times the beam width. - new_batch_size = batch_size * self.max_beam_width + new_batch_size = batch_size + if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1: graph_batch_size = self.dist.tp_allgather( [can_run_cuda_graph, batch_size]) @@ -879,7 +928,7 @@ class PyTorchModelEngine(ModelEngine): self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests( cuda_graph_dummy_request_ids, is_gen=True, - max_num_draft_tokens=self.max_draft_len, + max_num_draft_tokens=self.runtime_draft_len, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width)[0] self.cuda_graph_dummy_request.is_cuda_graph_dummy = True @@ -977,8 +1026,8 @@ class PyTorchModelEngine(ModelEngine): self._cuda_graphs[batch_size] = {} self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner( - num_sequences_in_batch, "cuda", attn_metadata, spec_metadata, - self.use_mrope) + batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope, + self.max_beam_width) return self._cuda_graphs[batch_size][draft_len] def __del__(self) -> None: @@ -1015,6 +1064,9 @@ class PyTorchModelEngine(ModelEngine): validate_and_set_kv_cache_quant( config, self.pytorch_backend_config.kv_cache_dtype) + validate_and_set_mamba_ssm_cache_dtype( + config, self.pytorch_backend_config.mamba_ssm_cache_dtype) + num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0")) if num_layers > 0: config.pretrained_config.num_hidden_layers = num_layers @@ -1058,7 +1110,7 @@ class PyTorchModelEngine(ModelEngine): else: weights = checkpoint_loader.load_weights(checkpoint_dir) - weight_mapper = checkpoint_loader.get_initilized_weight_mapper( + weight_mapper = checkpoint_loader.get_initialized_weight_mapper( model, config) self._call_load_weights(model.load_weights, weights, weight_mapper) @@ -1223,11 +1275,13 @@ class PyTorchModelEngine(ModelEngine): multimodal_params = MultimodalParams( multimodal_data=request.py_multimodal_data, multimodal_runtime=py_multimodal_runtime) - multimodal_params.to_device("multimodal_data", - "cuda", - pin_memory=True) if multimodal_params.has_content(): + multimodal_params.to_device("multimodal_data", + "cuda", + pin_memory=True) + #re-assign the multimodal_data to the request after to_device for generation requests + request.py_multimodal_data = multimodal_params.multimodal_data multimodal_params_list.append(multimodal_params) request.py_batch_idx = request.py_seq_slot @@ -1261,10 +1315,12 @@ class PyTorchModelEngine(ModelEngine): multimodal_params = MultimodalParams( multimodal_data=request.py_multimodal_data) multimodal_params.strip_for_generation() - multimodal_params.to_device("multimodal_data", - "cuda", - pin_memory=True) if multimodal_params.has_content(): + multimodal_params.to_device("multimodal_data", + "cuda", + pin_memory=True) + # re-assign the multimodal_data to the request after strip_for_generation for another generation request, + request.py_multimodal_data = multimodal_params.multimodal_data multimodal_params_list.append(multimodal_params) extend_requests += extend_dummy_requests @@ -1306,7 +1362,7 @@ class PyTorchModelEngine(ModelEngine): gather_ids.extend( list( range(len(position_ids), - len(position_ids) + 1 + self.max_draft_len))) + len(position_ids) + 1 + self.runtime_draft_len))) position_ids.extend( list( range(past_seen_token_num, @@ -1322,23 +1378,23 @@ class PyTorchModelEngine(ModelEngine): # inputs # overlap scheduler can only support the speculative decoding # methods with a fixed number of draft tokens - sequence_lengths.append(1 + self.max_draft_len) + sequence_lengths.append(1 + self.runtime_draft_len) past_seen_token_num = request.max_beam_num_tokens - 1 - draft_lens.append(self.max_draft_len) + draft_lens.append(self.runtime_draft_len) gather_ids.extend( list( range(len(position_ids), - len(position_ids) + 1 + self.max_draft_len))) + len(position_ids) + 1 + self.runtime_draft_len))) position_ids.extend( list( - range(past_seen_token_num, - past_seen_token_num + 1 + self.max_draft_len))) + range(past_seen_token_num, past_seen_token_num + 1 + + self.runtime_draft_len))) # previous tensor previous_batch_indices.append(previous_batch_idx) previous_pos_indices.extend([previous_batch_idx] * - (1 + self.max_draft_len)) + (1 + self.runtime_draft_len)) num_cached_tokens_per_seq.append(past_seen_token_num + - self.max_draft_len + 1) + self.runtime_draft_len + 1) prompt_lengths.append(request.py_prompt_len) request_ids.append(request.py_request_id) @@ -1372,8 +1428,11 @@ class PyTorchModelEngine(ModelEngine): gather_ids.append(len(position_ids) - 1) request_ids.append(request.py_request_id) - gen_request_seq_slots.append(request.py_seq_slot) request.py_batch_idx = request.py_seq_slot + # Do not add a gen_request_seq_slot for CUDA graph dummy requests + # to prevent access errors due to None values + if not request.is_cuda_graph_dummy: + gen_request_seq_slots.append(request.py_seq_slot) previous_batch_len = len(previous_batch_indices) @@ -1412,21 +1471,21 @@ class PyTorchModelEngine(ModelEngine): previous_slots = previous_seq_slots_device() # previous input ids previous_batch_tokens = previous_batch_len * ( - 1 + self.max_draft_len) + 1 + self.runtime_draft_len) new_tokens = new_tokens_device.transpose( 0, 1)[previous_slots, :].flatten() self.input_ids_cuda[num_tokens:num_tokens + previous_batch_tokens].copy_( new_tokens, non_blocking=True) # previous draft tokens - previous_batch_draft_tokens = previous_batch_len * self.max_draft_len + previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens + previous_batch_draft_tokens].copy_( next_draft_tokens_device[ previous_slots, :].flatten(), non_blocking=True) # prepare data for the preprocess inputs - kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1 + kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1 previous_pos_indices_host = torch.tensor(previous_pos_indices, dtype=torch.int, pin_memory=True) @@ -1451,8 +1510,8 @@ class PyTorchModelEngine(ModelEngine): extend_dummy_requests) self.previous_pos_id_offsets_cuda[ (num_extend_reqeust_wo_dummy - previous_batch_len) * - (1 + self.max_draft_len):num_extend_reqeust_wo_dummy * - (1 + self.max_draft_len)].copy_( + (1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy * + (1 + self.runtime_draft_len)].copy_( new_tokens_lens_device[self.previous_pos_indices_cuda[ 0:previous_batch_tokens]], non_blocking=True) @@ -1502,11 +1561,11 @@ class PyTorchModelEngine(ModelEngine): pin_memory=True, ) - num_generation_requests = len(scheduled_requests.generation_requests) + num_generation_requests = len(gen_request_seq_slots) # Cache indirection is only used for beam search on generation requests if self.use_beam_search and num_generation_requests > 0: # CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph - is_cuda_graph_during_warmup = self.in_warmup and attn_metadata.is_cuda_graph + is_cuda_graph_during_warmup = self.is_warmup and attn_metadata.is_cuda_graph if cache_indirection_buffer is not None: #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i self.cache_indirection_attention[:num_generation_requests].copy_( @@ -2075,7 +2134,7 @@ class PyTorchModelEngine(ModelEngine): cache_indirection_buffer: Optional[torch.Tensor] = None): if self.mapping is not None and 'cp_type' in self.mapping.cp_config: cp_type = self.mapping.cp_config['cp_type'] - if 'star_attention' == cp_type: + if CpType.STAR == cp_type: return self._prepare_star_attention_inputs( scheduled_requests, kv_cache_manager, attn_metadata) else: @@ -2116,14 +2175,8 @@ class PyTorchModelEngine(ModelEngine): spec_resource_manager = None spec_metadata = None - moe_load_balancer = None - if hasattr(self, 'moe_load_balancer'): - moe_load_balancer = getattr(self, 'moe_load_balancer') - if not self.in_warmup: - moe_enable_statistic = True - moe_enable_update = True - moe_load_balancer.set_next_iter_info(moe_enable_statistic, - moe_enable_update) + moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer', + None) if kv_cache_manager is None: inputs, gather_ids = self._prepare_tp_inputs_no_cache( @@ -2232,12 +2285,12 @@ class PyTorchModelEngine(ModelEngine): # Disable UB for unsupported platforms if not ub.ub_supported(): return False - ub.initialize_userbuffers_manager(self.mapping.tp_size, - self.mapping.pp_size, - self.mapping.cp_size, - self.mapping.rank, - self.mapping.gpus_per_node, - hidden_size * self.max_num_tokens * 2) + use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC" + ub.initialize_userbuffers_manager( + self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size, + self.mapping.rank, self.mapping.gpus_per_node, + hidden_size * self.max_num_tokens * 2, use_nccl_symmetric) + return True def load_weights_from_target_model(self, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 03bb6c9a7f..8dbbe39abb 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -17,7 +17,8 @@ try: except ImportError: from cuda import cudart -from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType +from tensorrt_llm._torch.pyexecutor.resource_manager import ( + ResourceManagerType, request_context) from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank, is_trace_enabled, nvtx_range, trace_func) @@ -30,6 +31,7 @@ from tensorrt_llm.bindings.executor import (DisServingRequestStats, from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType, ReqIdsSet) from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import CpType from tensorrt_llm.runtime.generation import CUASSERT from ..distributed import Distributed @@ -159,7 +161,6 @@ class PyExecutor: self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes( PROFILE_START_STOP_ENV_VAR_NAME) self.gc_nvtx_watcher_handle = _gc_nvtx_watcher() - self.is_warmup = False # During warmup, we don't enable the profiler # related modules self.resource_manager = resource_manager @@ -218,9 +219,12 @@ class PyExecutor: self.inflight_req_ids = ReqIdsSet() + # During warmup, we don't enable the profiler + self.is_warmup = True self.model_engine.warmup(self.resource_manager) if self.draft_model_engine is not None: self.draft_model_engine.warmup(self.resource_manager) + self.is_warmup = False self.is_shutdown = False self.max_batch_size = max_batch_size @@ -278,16 +282,25 @@ class PyExecutor: finally: self._executor_loop_cleanup() + @property + def is_warmup(self) -> bool: + return getattr(self, "_is_warmup", False) + + @is_warmup.setter + def is_warmup(self, value: bool): + self._is_warmup = value + # Set warmup flag in model engine to trigger torch compile and avoid moe load balancer statistics update + self.model_engine.is_warmup = value + if self.draft_model_engine is not None: + self.draft_model_engine.is_warmup = value + def start_worker(self): - self.worker_lock.acquire() - try: + with self.worker_lock: if self.worker_started == False: self.worker_thread = threading.Thread( target=self._event_loop_wrapper, daemon=True) self.worker_thread.start() self.worker_started = True - finally: - self.worker_lock.release() def __enter__(self): return self @@ -364,13 +377,9 @@ class PyExecutor: return [] latest_stats = (IterationStats(), None) - try: - self.stats_lock.acquire() + with self.stats_lock: latest_stats = self.stats self.stats = [] - finally: - self.stats_lock.release() - return latest_stats def get_latest_kv_cache_events(self): @@ -408,6 +417,16 @@ class PyExecutor: it = -1 enabled = False start_time = None + + # These events are used to record the time of the previous batch. + # We need two set of the start-end events to record the time through + # a ping-pong way so that it works with overlap scheduler. + start_event_1 = None + end_event_1 = torch.cuda.Event(enable_timing=True) + start_event_2 = None + end_event_2 = torch.cuda.Event(enable_timing=True) + prev_device_step_time = None + torch_trace_path = os.environ.get(PROFILE_TRACE_ENV_VAR_NAME, None) profile_start_stop = os.environ.get(PROFILE_START_STOP_ENV_VAR_NAME, None) @@ -430,7 +449,7 @@ class PyExecutor: with_modules=True) def profile_step(): - nonlocal it, enabled, start_time + nonlocal it, enabled, start_time, start_event_1, end_event_1, start_event_2, end_event_2, prev_device_step_time if it in self.profile_stop_iters and not self.is_warmup: assert enabled, "Inconsistent CUDA profiling state" if enable_torch_trace: @@ -443,15 +462,34 @@ class PyExecutor: if start_time is not None and self.print_log and self.dist.rank == 0: end_time = time.time() + if it % 2 == 0: + end_event_1.record() + if start_event_2 is not None: + end_event_2.synchronize() + prev_device_step_time = start_event_2.elapsed_time( + end_event_2) + else: + end_event_2.record() + if start_event_1 is not None: + end_event_1.synchronize() + prev_device_step_time = start_event_1.elapsed_time( + end_event_1) + if prev_device_step_time is None: + prev_device_step_time = "N/A" # Handle first iteration + else: + prev_device_step_time = f"{prev_device_step_time}ms" + host_step_time = (end_time - start_time) * 1000 # milliseconds formatted_timestamp = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S") logger.info( f"iter = {self.model_engine.iter_counter}, " f"global_rank = {self.global_rank}, " f"rank = {self.dist.rank}, " - f"currank_total_requests = {self.num_fetch_requests_cur_rank}/{self.num_fetch_requests}, " - f"elapsed_time = {end_time - start_time}s, " + f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/" + f"{self.executor_request_queue.num_fetch_requests}, " + f"host_step_time = {host_step_time}ms, " + f"prev_device_step_time = {prev_device_step_time}, " f"timestamp = {formatted_timestamp}, " f"num_scheduled_requests: {self.num_scheduled_requests}, " f"states = {self.model_engine.iter_states}") @@ -466,6 +504,14 @@ class PyExecutor: logger.info(f"Profiling started at iteration {it}.") enabled = True start_time = time.time() + if it % 2 == 0: + if start_event_1 is None: + start_event_1 = torch.cuda.Event(enable_timing=True) + start_event_1.record() + else: + if start_event_2 is None: + start_event_2 = torch.cuda.Event(enable_timing=True) + start_event_2.record() try: yield profile_step @@ -605,11 +651,8 @@ class PyExecutor: stats: IterationStats, req_stats: Optional[List[RequestStats]] = None): - try: - self.stats_lock.acquire() + with self.stats_lock: self.stats.append((stats, req_stats)) - finally: - self.stats_lock.release() def _process_iter_stats(self, finished_requests: list[LlmRequest], active_requests: List[LlmRequest], @@ -749,6 +792,9 @@ class PyExecutor: if self._need_return_logits(scheduled_batch): logits_host = batch_outputs["logits"].to( "cpu", non_blocking=True) + if self.kv_cache_transceiver and self.guided_decoder: + self.guided_decoder.init_disagg_gen_requests( + scheduled_batch) self._execute_guided_decoder( scheduled_batch, batch_outputs['logits']) @@ -875,6 +921,10 @@ class PyExecutor: self.use_spec_decode = self.drafter.should_use_spec_decode( self.active_requests) self.model_engine.enable_spec_decode = self.use_spec_decode + # If speculation is off, this function sets py_draft_tokens to None + # for all active requests. If it's on, we initialize py_draft_tokens + # with dummy draft tokens to make the scheduler aware of the fact + # that speculation is about to happen. self._prepare_draft_requests() scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( @@ -897,7 +947,8 @@ class PyExecutor: f'{len(scheduled_batch.generation_requests)} generation requests') return scheduled_batch, iter_stats - def _execute_guided_decoder(self, scheduled_batch, logits): + def _execute_guided_decoder(self, scheduled_batch: ScheduledRequests, + logits: torch.Tensor): if self.guided_decoder is not None: self.guided_decoder.build(scheduled_batch) self.guided_decoder.execute(scheduled_batch, logits) @@ -934,9 +985,19 @@ class PyExecutor: self._handle_first_token_response(scheduled_batch) self.resource_manager.prepare_resources(scheduled_batch) + + if self.kv_cache_transceiver and self.guided_decoder: + self.guided_decoder.init_disagg_gen_requests( + scheduled_batch) if self.drafter is not None and self.use_spec_decode: - self.drafter.prepare_draft_tokens( - scheduled_batch, self.resource_manager) + with request_context( + is_draft=True, + scheduled_requests=scheduled_batch): + if self.guided_decoder is not None: + self.guided_decoder.rollback_rejected_tokens( + scheduled_batch) + self.drafter.prepare_draft_tokens( + scheduled_batch, self.resource_manager) batch_outputs = self._forward_step(scheduled_batch) self._execute_guided_decoder(scheduled_batch, @@ -1055,6 +1116,9 @@ class PyExecutor: if self.previous_batch is not None: self._update_requests(self.previous_batch.sample_state) + if self.kv_cache_transceiver and self.guided_decoder: + self.guided_decoder.init_disagg_gen_requests( + scheduled_batch) self._execute_guided_decoder(scheduled_batch, batch_outputs['logits']) @@ -1291,7 +1355,6 @@ class PyExecutor: for resource_mgr_type in ( ResourceManagerType.KV_CACHE_MANAGER, - ResourceManagerType.SEQ_SLOT_MANAGER, ResourceManagerType.SPEC_RESOURCE_MANAGER, ResourceManagerType.DRAFT_KV_CACHE_MANAGER): if (resource_mgr_type in self.resource_manager.resource_managers @@ -1311,7 +1374,12 @@ class PyExecutor: if req.is_disagg_generation_transmission_complete: cache_trans_complete_requests.append(req) if len(cache_trans_complete_requests) > 0: - self._setup_sampler_step(cache_trans_complete_requests) + requests = ScheduledRequests() + requests.context_requests = cache_trans_complete_requests + self.resource_manager.resource_managers[ + ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources( + requests) + self._setup_sampler_step(requests) for req in scheduled_batch.generation_requests: if req.is_disagg_generation_transmission_complete: @@ -1382,7 +1450,7 @@ class PyExecutor: new_tensors_device: Optional[SampleStateTensors] = None): @nvtx_range( - f"[Executor] _forward_step {self.model_engine.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs" + f"[Executor] _forward_step {self.model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs" ) def forward(scheduled_requests, resource_manager, new_tensors_device, gather_context_logits, cache_indirection_buffer): @@ -1445,7 +1513,7 @@ class PyExecutor: cp_config = self.dist.cp_config if 'cp_type' in cp_config: cp_type = cp_config['cp_type'] - if cp_type == 'star_attention': + if cp_type == CpType.STAR: self._update_request_states_star_attention(scheduled_requests) else: assert False, f'Unsupport cp_type {cp_type}' @@ -1465,7 +1533,7 @@ class PyExecutor: self._handle_errors(error_msg) @nvtx_range("_setup_sampler_step") - def _setup_sampler_step(self, requests): + def _setup_sampler_step(self, requests: ScheduledRequests): try: return self.sampler.setup_sampler_step(requests) except Exception as e: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index a0194bc6db..20e2cfea15 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -13,7 +13,7 @@ from tensorrt_llm._utils import get_sm_family, get_sm_version from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig from tensorrt_llm.logger import logger -from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.quantization import QuantAlgo @@ -33,6 +33,7 @@ from .py_executor import PyExecutor class _ExecutorCreationStage(enum.Enum): SAMPLER = "Sampler" DRAFTER = "Drafter" + GUIDED_DECODER = "Guided decoder" INIT_KV_CACHE = "Initial KV cache (temporary for KV cache size estimation)" INIT_EXTRA_RESOURCES = "Additional executor resources (temporary for KV cache size estimation)" MODEL_EXTRA = "Model resources created during usage" @@ -169,6 +170,14 @@ def _mangle_executor_config(executor_config: ExecutorConfig): ) executor_config.enable_chunked_context = False + spec_config = executor_config.speculative_config + if not executor_config.pytorch_backend_config.disable_overlap_scheduler and spec_config is not None: + if not spec_config.spec_dec_mode.support_overlap_scheduler(): + logger.warning( + f"Disable overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}" + ) + executor_config.pytorch_backend_config.disable_overlap_scheduler = True + def _get_mapping(executor_config: ExecutorConfig) -> Mapping: if executor_config.mapping is None: @@ -297,10 +306,10 @@ def create_py_executor( f"disable enable_block_reuse for KV cache quant algorithm: {kv_cache_quant_algo}" ) executor_config.kv_cache_config.enable_block_reuse = False - if executor_config.enable_chunked_context and not (get_sm_family() - == 100): + if executor_config.enable_chunked_context and not ( + get_sm_family() == 100 or get_sm_version() == 90): logger.warning( - "Chunked Prefill for MLA can only be enabled on SM100f, " + "Chunked Prefill for MLA can only be enabled on SM90/100f, " f"disable enable_block_reuse for SM{get_sm_version()}") executor_config.enable_chunked_context = False model_engine.attn_runtime_features.chunked_prefill = False @@ -326,20 +335,28 @@ def create_py_executor( else: ctx_chunk_config = None + with mem_monitor.observe_creation_stage( + _ExecutorCreationStage.GUIDED_DECODER): + guided_decoder: Optional[GuidedDecoder] = None + if executor_config.guided_decoding_config is not None: + if spec_config is not None and not has_spec_drafter: + raise ValueError( + "Guided decoding is only supported with speculative decoding that has a dedicated drafter (two-model engine)." + ) + if mapping.is_last_pp_rank(): + max_num_draft_tokens = 0 + if spec_config is not None: + max_num_draft_tokens = spec_config.max_draft_len + guided_decoder = GuidedDecoder( + executor_config.guided_decoding_config, + executor_config.max_batch_size, + model_engine.model.vocab_size_padded, + max_num_draft_tokens=max_num_draft_tokens) + with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER): sampler = instantiate_sampler(model_engine, executor_config, pytorch_backend_config, mapping) - - guided_decoder: Optional[GuidedDecoder] = None - if executor_config.guided_decoding_config is not None: - if spec_config is not None: - raise ValueError( - "Guided decoding is not supported with speculative decoding.") - if mapping.is_last_pp_rank(): - guided_decoder = GuidedDecoder( - executor_config.guided_decoding_config, - executor_config.max_batch_size, - model_engine.model.vocab_size_padded) + logger.info(f"Using Sampler: {type(sampler).__name__}") resources = {} estimating_kv_cache = False @@ -368,8 +385,11 @@ def create_py_executor( # Drafter for speculative decoding with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER): - drafter = get_spec_drafter(model_engine, draft_model_engine, sampler, - spec_resource_manager) + drafter = get_spec_drafter(model_engine, + draft_model_engine, + sampler, + spec_resource_manager=spec_resource_manager, + guided_decoder=guided_decoder) with mem_monitor.observe_creation_stage( _ExecutorCreationStage.INIT_EXTRA_RESOURCES diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 9f44649b49..9a5b42166d 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -10,11 +10,13 @@ import torch import tensorrt_llm import tensorrt_llm.bindings from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE +from tensorrt_llm.lora_helper import LoraConfig +from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig from tensorrt_llm.sampling_params import SamplingParams -from ..._utils import binding_dtype_size, nvtx_range +from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range from ...logger import logger -from ...mapping import Mapping +from ...mapping import CpType, Mapping from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig, get_draft_token_length) from .scheduler import ScheduledRequests @@ -109,6 +111,33 @@ def get_pp_layers( return pp_layers, total_num_layers +def request_context(is_draft: bool, scheduled_requests: ScheduledRequests): + + class RequestContext: + + def __init__(self, is_draft: bool, + scheduled_requests: ScheduledRequests): + self.is_draft = is_draft + self.scheduled_requests = scheduled_requests + + def __enter__(self): + if not self.is_draft: + return + + for req in self.scheduled_requests.all_requests(): + req.use_draft_model = True + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.is_draft: + return + + # Clean up the state + for req in self.scheduled_requests.all_requests(): + req.use_draft_model = False + + return RequestContext(is_draft, scheduled_requests) + + class KVCacheManager(BaseResourceManager): def __init__( @@ -131,6 +160,7 @@ class KVCacheManager(BaseResourceManager): max_num_tokens: int = 8192, model_config: Optional[ModelConfig] = None, max_beam_width: int = 1, + is_draft: bool = False, ) -> None: self.mapping = mapping self.dtype = dtype @@ -141,6 +171,7 @@ class KVCacheManager(BaseResourceManager): spec_config=spec_config, layer_mask=layer_mask, ) + self.is_draft = is_draft self.num_local_layers = len(self.pp_layers) self.layer_offsets = { idx: offset @@ -196,6 +227,7 @@ class KVCacheManager(BaseResourceManager): from ..speculative import get_num_extra_kv_tokens self.num_extra_kv_tokens = get_num_extra_kv_tokens(spec_config) self.event_buffer_max_size = kv_cache_config.event_buffer_max_size + self.attention_dp_events_gather_period_ms = kv_cache_config.attention_dp_events_gather_period_ms self.max_num_tokens = max_num_tokens # Determine max_attention_window_vec @@ -299,8 +331,17 @@ class KVCacheManager(BaseResourceManager): 'copy_on_partial_reuse': kv_cache_config.copy_on_partial_reuse, } if self.event_buffer_max_size > 0: - kwargs['event_manager'] = KVCacheEventManagerCpp( - max_kv_event_entries=self.event_buffer_max_size) + if mapping.enable_attention_dp: + kwargs['event_manager'] = KVCacheEventManagerCpp( + max_kv_event_entries=self.event_buffer_max_size, + attention_dp_rank=mapping.rank, + attention_dp_size=mapping.world_size, + attention_dp_events_gather_period_ms=self. + attention_dp_events_gather_period_ms, + ) + else: + kwargs['event_manager'] = KVCacheEventManagerCpp( + max_kv_event_entries=self.event_buffer_max_size) self.impl = KVCacheManagerCpp(**kwargs) @@ -355,34 +396,36 @@ class KVCacheManager(BaseResourceManager): return need_blocks def prepare_resources(self, scheduled_batch: ScheduledRequests): - context_batch = scheduled_batch.context_requests - generation_batch = scheduled_batch.generation_requests - # allocate KV Cache - for req in context_batch: - req_beam_width = req.sampling_config.beam_width - if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[ - 'cp_type']: - if req.ctx_iters == 0: - seq_len = sum( - len(ctx_block) for ctx_block in req.ctx_blocks) - self.impl.add_sequence( - req.py_request_id, - seq_len + (len(req.query_id) if self.mapping.cp_rank - == self.mapping.cp_size - 1 else 0), - req_beam_width, req) - else: - if req.is_first_context_chunk: - self.impl.add_sequence(req.py_request_id, req.prompt_len, - req_beam_width, req) - for _ in range(self.num_extra_kv_tokens): - self.impl.add_token(req.py_request_id) - for _ in range(get_draft_token_length(req)): - self.impl.add_token(req.py_request_id) + with request_context(self.is_draft, scheduled_batch): + context_batch = scheduled_batch.context_requests + generation_batch = scheduled_batch.generation_requests + # allocate KV Cache + for req in context_batch: + req_beam_width = req.sampling_config.beam_width + if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[ + 'cp_type']: + if req.ctx_iters == 0: + seq_len = sum( + len(ctx_block) for ctx_block in req.ctx_blocks) + self.impl.add_sequence( + req.py_request_id, + seq_len + (len(req.query_id) if self.mapping.cp_rank + == self.mapping.cp_size - 1 else 0), + req_beam_width, req) + else: + if req.is_first_context_chunk: + self.impl.add_sequence(req.py_request_id, + req.prompt_len, req_beam_width, + req) + for _ in range(self.num_extra_kv_tokens): + self.impl.add_token(req.py_request_id) + for _ in range(get_draft_token_length(req)): + self.impl.add_token(req.py_request_id) - for req in generation_batch: - self.impl.add_token(req.py_request_id) - for _ in range(get_draft_token_length(req)): + for req in generation_batch: self.impl.add_token(req.py_request_id) + for _ in range(get_draft_token_length(req)): + self.impl.add_token(req.py_request_id) def add_dummy_requests( self, @@ -450,6 +493,10 @@ class KVCacheManager(BaseResourceManager): if request.py_rewind_len > 0: self.rewind_kv_cache(request, request.py_rewind_len) + # For context requests, we store the blocks for reuse. + for request in scheduled_batch.context_requests: + self.impl.store_context_blocks(request) + def free_resources(self, request: LlmRequest): self.impl.remove_sequence(request.py_request_id, request) @@ -884,218 +931,6 @@ class KVCacheManager(BaseResourceManager): return None -class MambaCacheManager(BaseResourceManager): - - def __init__( - self, - d_state: int, - d_conv: int, - num_heads: int, - n_groups: int, - head_dim: int, - num_layers: int, - max_batch_size: int, - mapping: Mapping, - dtype: torch.dtype, - layer_mask: Optional[List[bool]] = None, - ) -> None: - - # get tp size - tp_size = mapping.tp_size - - # derive mamba parameters for conv and ssm states - d_inner = head_dim * num_heads - conv_dim = d_inner + 2 * n_groups * d_state - nheads = num_heads - - # check that can be partitioned - assert nheads % tp_size == 0, "nheads must be divisible by tp_size" - assert conv_dim % tp_size == 0, "conv_dim must be divisible by tp_size" - - # partition conv_dim and nheads - conv_dim = conv_dim // tp_size - nheads = nheads // tp_size - - # conv and ssm states device - device = torch.device("cuda") - - pp_layers, num_layers = get_pp_layers( - num_layers, - mapping, - layer_mask=layer_mask, - ) - num_local_layers = len(pp_layers) - self.mamba_layer_offsets = { - idx: offset - for offset, idx in enumerate(pp_layers) - } - - # mamba conv states - self.conv_states = torch.empty( - size=[ - num_local_layers, - max_batch_size, - conv_dim, - d_conv - 1, - ], - dtype=dtype, - device=device, - ) - - # mamba ssm states - self.ssm_states = torch.empty( - size=[ - num_local_layers, - max_batch_size, - nheads, - head_dim, - d_state, - ], - dtype=dtype, - device=device, - ) - - # mamba cache available blocks - self.mamba_cache_free_blocks = [i for i in range(max_batch_size)] - - # mamba cache index, maps request_id -> state indices - self.mamba_cache_index: Dict[int, int] = {} - - # mamba cache state indices - self.state_indices: torch.Tensor = torch.arange(max_batch_size, - device=device, - dtype=torch.int32) - - def _prepare_mamba_cache_blocks(self, request_ids: List[int]): - state_indices = [] - for r in request_ids: - # cache hit - if r in self.mamba_cache_index: - state_indices.append(self.mamba_cache_index[r]) - # cache miss - else: - if len(self.mamba_cache_free_blocks) == 0: - raise Exception("run out of mamba cache blocks") - block = self.mamba_cache_free_blocks.pop() - self.mamba_cache_index[r] = block - state_indices.append(block) - self.state_indices[:len(state_indices)] = torch.as_tensor( - state_indices, dtype=torch.int32, device=self.ssm_states.device) - - def prepare_resources(self, scheduled_batch: ScheduledRequests): - context_ids = [ - i.py_request_id for i in scheduled_batch.context_requests - ] - generation_ids = [ - i.py_request_id for i in scheduled_batch.generation_requests - ] - request_ids = context_ids + generation_ids - self._prepare_mamba_cache_blocks(request_ids) - - def free_resources(self, request: LlmRequest): - request_id = request.py_request_id - if request_id in self.mamba_cache_index: - block = self.mamba_cache_index.pop(request_id) - self.mamba_cache_free_blocks.append(block) - - def get_state_indices(self) -> torch.Tensor: - return self.state_indices - - def get_conv_states(self, layer_idx: int) -> torch.Tensor: - layer_offset = self.mamba_layer_offsets[layer_idx] - return self.conv_states[layer_offset] - - def get_ssm_states(self, layer_idx: int) -> torch.Tensor: - layer_offset = self.mamba_layer_offsets[layer_idx] - return self.ssm_states[layer_offset] - - def shutdown(self): - # release tensor memory, keeping python references as tensors - self.conv_states = torch.tensor([]) - self.ssm_states = torch.tensor([]) - self.state_indices = torch.tensor([]) - torch.cuda.empty_cache() - - -class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): - - def __init__( - self, - # mamba cache parameters - mamba_d_state: int, - mamba_d_conv: int, - mamba_num_heads: int, - mamba_n_groups: int, - mamba_head_dim: int, - mamba_num_layers: int, - mamba_layer_mask: List[bool], - mamba_cache_dtype: torch.dtype, - # kv cache parameters - kv_cache_config: KvCacheConfigCpp, - kv_cache_type: CacheTypeCpp, - *, - num_layers: int, - layer_mask: List[bool], - num_kv_heads: Union[int, List[Optional[int]]], - head_dim: int, - tokens_per_block: int, - # Note that max_seq_len is not necessarily equal to kv_cache_config.num_tokens. - # It's derived from the model's BuildConfig for consistency with the C++ backend. - max_seq_len: int, - max_batch_size: int, - mapping: Mapping, - dtype: DataType = DataType.HALF, - spec_config: Optional["DecodingBaseConfig"] = None, - ) -> None: - - # mamba hybrid cache requires block reuse to be disabled in KV cache config - assert not kv_cache_config.enable_block_reuse, "mamba hybrid cache requires block reuse to be disabled in KV cache config" - - # initialize mamba cache manager - MambaCacheManager.__init__( - self, - mamba_d_state, - mamba_d_conv, - mamba_num_heads, - mamba_n_groups, - mamba_head_dim, - mamba_num_layers, - max_batch_size, - mapping, - mamba_cache_dtype, - mamba_layer_mask, - ) - - # initialize kv cache manager - KVCacheManager.__init__( - self, - kv_cache_config, - kv_cache_type, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - tokens_per_block=tokens_per_block, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - mapping=mapping, - dtype=dtype, - spec_config=spec_config, - layer_mask=layer_mask, - ) - - def prepare_resources(self, scheduled_batch: ScheduledRequests): - MambaCacheManager.prepare_resources(self, scheduled_batch) - KVCacheManager.prepare_resources(self, scheduled_batch) - - def free_resources(self, request: LlmRequest): - MambaCacheManager.free_resources(self, request) - KVCacheManager.free_resources(self, request) - - def shutdown(self): - MambaCacheManager.shutdown(self) - KVCacheManager.shutdown(self) - - class SlotManager: def __init__(self, max_num_requests: int): @@ -1170,6 +1005,7 @@ class PeftCacheManager(BaseResourceManager): def __init__(self, peft_cache_config: PeftCacheConfig, + lora_config: LoraConfig, model_config: ModelConfig, world_config: WorldConfig | None = None): import tensorrt_llm.bindings as _tb @@ -1200,8 +1036,36 @@ class PeftCacheManager(BaseResourceManager): model_config=model_config, world_config=world_config, buffer_manager=buffer_manager) + self._lora_config = lora_config + self._lora_model_config = LoraModelConfig( + lora_config.lora_target_modules, + lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size, + binding_to_str_dtype(model_config.data_type)) + self._lora_manager = LoraManager() def add_request_peft(self, request: LlmRequest): + if request.lora_task_id is not None: + is_task_cached = self.impl.is_task_cached(request.lora_task_id) + if is_task_cached: + # PeftCacheManager::addRequestPeft in CPP doesn't allow having only one of [config tensor, weights + # tensor] without the other. Since there's no need for any of them when the LoRA adapter is already + # cached, we can safely remove both from the request. + request.remove_lora_tensors() + elif request.lora_weights is None and request.py_lora_path: + self._lora_manager.load_from_ckpt( + [request.py_lora_path], + model_config=self._lora_model_config, + runtime_mapping=None, + uids=[request.lora_task_id], + ckpt_source=self._lora_config.lora_ckpt_source) + request.lora_weights = self._lora_manager.cpp_lora_weights[ + request.lora_task_id] + + # PeftCacheManager CPP implementation expects an extra dim at index 0 + if request.lora_weights is not None: + request.lora_weights = request.lora_weights.unsqueeze(0) + if request.lora_config is not None: + request.lora_config = request.lora_config.unsqueeze(0) self.impl.add_request_peft(request, True) def ensure_batch(self, @@ -1221,12 +1085,7 @@ class PeftCacheManager(BaseResourceManager): context_batch = scheduled_batch.context_requests generation_batch = scheduled_batch.generation_requests for req in context_batch: - if req.lora_weights is not None and req.lora_config is not None: - req.lora_weights = req.lora_weights.reshape( - [1] + list(req.lora_weights.shape)) - req.lora_config = req.lora_config.reshape( - [1] + list(req.lora_config.shape)) - self.impl.add_request_peft(req, True) + self.add_request_peft(req) py_lora_task_layer_module_configs = self.impl.ensure_batch( context_batch, generation_batch, False) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index daaac14c5a..2fa0e4e331 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable from dataclasses import dataclass -from typing import Literal +from typing import Literal, Optional import torch @@ -97,7 +97,9 @@ class EarlyStopSampler(Sampler): request.py_result.append_context_logits(logits) -def top_k_sampling_batch(logits, top_k=50): +def top_k_sampling_batch(logits, + top_k=50, + generator: Optional[torch.Generator] = None): logits_dim = logits.dim() if logits_dim == 1: logits = logits.unsqueeze(0) @@ -105,30 +107,83 @@ def top_k_sampling_batch(logits, top_k=50): batch_size, vocab_size = logits.size() # get first top_k logits of each sample and their indices - values, indices = torch.topk(logits, top_k, dim=-1) - min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) + if top_k > 0: + values, indices = torch.topk(logits, top_k, dim=-1) + min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) - # set the logits who is less than first top_k logits to -inf - logits = torch.where(logits < min_values, - torch.full_like(logits, float('-inf')), logits) + # set the logits who is less than first top_k logits to -inf + logits = torch.where(logits < min_values, + torch.full_like(logits, float('-inf')), logits) # compute probability distribution softmax = torch.softmax(logits, dim=-1) # sample from the distribution and generate result of [batch_size, 1] - next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1) + next_tokens = torch.multinomial(softmax, num_samples=1, + generator=generator).squeeze(-1) return next_tokens, softmax -def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9): +def top_p_sampling_batch(logits: torch.Tensor, + top_p: float = 0.9, + temperature: float = 1.0, + generator: Optional[torch.Generator] = None): logits_dim = logits.dim() if logits_dim == 1: logits = logits.unsqueeze(0) assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" + if temperature != 0: + logits = logits / max(temperature, 1e-5) + # sort the logits of each sample in descending order sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + # compute cumulative probability distribution of each sample + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), + dim=-1) + # get the location of top_p + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = 0 + + # set the logits to -inf whose is outside top_p + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, float('-inf')) + + # compute probability distribution + softmax = torch.softmax(logits, dim=-1) + + # sample from the distribution and generate result of [batch_size, 1] + next_tokens = torch.multinomial(softmax, num_samples=1, + generator=generator).squeeze(-1) + return next_tokens, softmax + + +def top_k_top_p_sampling_batch(logits: torch.Tensor, + top_k: int, + top_p: float, + temperature: float = 1.0, + generator: Optional[torch.Generator] = None): + logits_dim = logits.dim() + if logits_dim == 1: + logits = logits.unsqueeze(0) + assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" + if temperature != 0: + logits = logits / max(temperature, 1e-5) + batch_size, vocab_size = logits.size() + # get first top_k logits of each sample and their indices + if top_k > 0: + values, indices = torch.topk(logits, top_k, dim=-1) + min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) + + # set the logits who is less than first top_k logits to -inf + logits = torch.where(logits < min_values, + torch.full_like(logits, float('-inf')), logits) + + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + # compute cumulative probability distribution of each sample cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) @@ -147,7 +202,8 @@ def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9): softmax = torch.softmax(logits, dim=-1) # sample from the distribution and generate result of [batch_size, 1] - next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1) + next_tokens = torch.multinomial(softmax, num_samples=1, + generator=generator).squeeze(-1) return next_tokens, softmax @@ -157,17 +213,54 @@ def greedy_search_sampling_batch(logits): return next_tokens, softmax +def get_rejected_indices(draft_probs: torch.Tensor, target_probs: torch.Tensor, + generator: torch.Generator, draft_tokens: list[int]): + + p = draft_probs[torch.arange(len(draft_tokens)), draft_tokens] + q = target_probs[:-1] + q = q[torch.arange(len(draft_tokens)), draft_tokens] + accept_probs = torch.minimum(torch.ones(()), q / p) + # Use deterministic random generation for multi-GPU consistency + rejected_indices = (torch.rand(accept_probs.shape, + generator=generator, + device=accept_probs.device) + > accept_probs).nonzero() + return rejected_indices + + +def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor, + generator: torch.Generator, num_accepted: int): + + last_draft = draft_probs[num_accepted] + last_target = target_probs[num_accepted] + new = last_target - last_draft + new = torch.where(new > 0, new, 0.0) + + new_token = torch.multinomial(new, num_samples=1, + generator=generator).squeeze(-1) + return new_token + + TopK = tuple[Literal["top_k"], int] -TopP = tuple[Literal["top_p"], float] +TopP = tuple[Literal["top_p"], float, float] +TopKTopP = tuple[Literal["top_k_top_p"], int, float, float] Greedy = tuple[Literal["greedy"], None] GREEDY: Greedy = ("greedy", None) Strategy = TopK | TopP | Greedy def request_strategy(request: LlmRequest) -> Strategy: + if request.sampling_config.top_k is not None and len( + request.sampling_config.top_k + ) > 0 and request.sampling_config.top_p is not None and len( + request.sampling_config.top_p) > 0: + return ("top_k_top_p", request.sampling_config.top_k[0], + request.sampling_config.top_p[0], + request.sampling_config.temperature[0]) if request.sampling_config.top_p is not None and len( request.sampling_config.top_p) > 0: - return ("top_p", request.sampling_config.top_p[0]) + return ("top_p", request.sampling_config.top_p[0], + request.sampling_config.temperature[0]) elif request.sampling_config.top_k is not None and len( request.sampling_config.top_k) > 0: return ("top_k", request.sampling_config.top_k[0]) @@ -179,12 +272,17 @@ def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]: return [request_strategy(req) for req in requests] -def sample(strategy: Strategy, logits: torch.Tensor): +def sample(strategy: Strategy, + logits: torch.Tensor, + generator: Optional[torch.Generator] = None): match strategy: case ("top_k", top_k): - return top_k_sampling_batch(logits, top_k) - case ("top_p", top_p): - return top_p_sampling_batch(logits, top_p) + return top_k_sampling_batch(logits, top_k, generator) + case ("top_p", top_p, temperature): + return top_p_sampling_batch(logits, top_p, temperature, generator) + case ("top_k_top_p", top_k, top_p, temperature): + return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature, + generator) case ("greedy", None): return greedy_search_sampling_batch(logits) @@ -230,9 +328,9 @@ class TorchSampler(Sampler): self.enable_mixed_sampler = args.enable_mixed_sampler self.max_tokens = args.max_draft_len + 1 assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1" - self.num_seq_slots = args.max_num_sequences + self.max_num_sequences = args.max_num_sequences - self.NEW_TOKENS_SHAPE = (self.max_tokens, self.num_seq_slots, + self.NEW_TOKENS_SHAPE = (self.max_tokens, self.max_num_sequences, self.MAX_BEAM_WIDTH) # AutoDeploy build creates the sampler in inference mode, # which would disallow in-place mutating of new_tokens. @@ -240,6 +338,25 @@ class TorchSampler(Sampler): with torch.inference_mode(False): self.store = self.create_store() + # Initialize seed for multi-GPU consistency + self._global_seed = 42 + self._generator = None + + def get_generator(self, device: torch.device) -> torch.Generator: + """Get a deterministic generator for the specified device. + + Args: + device: The device to create the generator on + + Returns: + A torch.Generator with the global seed set + """ + if self._generator is None: + # Fallback to a default seed if not set + self._generator = torch.Generator(device=device) + self._generator.manual_seed(self._global_seed) + return self._generator + def _meet_max_token_stop_criteria(self, request: LlmRequest): num_tokens = request.get_num_tokens(self.BEAM) return (num_tokens - request.py_orig_prompt_len @@ -301,13 +418,19 @@ class TorchSampler(Sampler): assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element" request.py_result.append_log_probs([token_log_probs]) - def process_draft_tokens(self, request: LlmRequest, - new_tokens: torch.Tensor, new_token: int) -> int: + def _process_draft_tokens_greedy(self, request: LlmRequest, + new_tokens: torch.Tensor) -> int: + new_token = add_token(request, new_tokens, beam=self.BEAM) + stop = self._handle_stop_criteria(request, new_token) + if stop or get_draft_token_length(request) == 0: + return 0 num_accepted = 0 + for draft_token in request.py_draft_tokens: if draft_token != new_token: # Reject. break + num_accepted += 1 new_token = add_token(request, new_tokens, @@ -317,6 +440,56 @@ class TorchSampler(Sampler): break return num_accepted + def _process_draft_tokens_rejection_sampling( + self, request: LlmRequest, new_tokens: torch.Tensor) -> int: + sampling_strategy = request_strategy(request) + generator = self.get_generator(request.py_draft_logits.device) + _, draft_probs = sample(sampling_strategy, + request.py_draft_logits[0], + generator=generator) + target_probs = request.py_target_probs + rejected_indices = get_rejected_indices(draft_probs, target_probs, + generator, + request.py_draft_tokens) + sample_last = True + stop = False + if rejected_indices.numel() == 0: + num_initially_accepted = get_draft_token_length(request) + sample_last = False + else: + num_initially_accepted = rejected_indices[0].item() + num_accepted = num_initially_accepted + for i in range(num_accepted): + new_token = request.py_draft_tokens[i] + new_tokens[i, request.seq_slot, self.BEAM] = new_token + request.add_new_token(new_token, self.BEAM) + stop = self._handle_stop_criteria(request, new_token) + if stop: + num_accepted = i + 1 + return num_accepted + if sample_last: + new_token = sample_rejected(draft_probs, target_probs, generator, + num_accepted) + new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token + request.add_new_token(new_token, self.BEAM) + stop = self._handle_stop_criteria(request, new_token) + else: + new_token = add_token(request, + new_tokens, + beam=self.BEAM, + step=num_accepted) + stop = self._handle_stop_criteria(request, new_token) + + return num_accepted + + def process_draft_tokens(self, request: LlmRequest, + new_tokens: torch.Tensor) -> int: + if request.py_draft_logits is None: + return self._process_draft_tokens_greedy(request, new_tokens) + else: + return self._process_draft_tokens_rejection_sampling( + request, new_tokens) + def update_requests(self, state: SampleState) -> None: assert isinstance(state, SampleState) if state.sampler_event: @@ -334,15 +507,12 @@ class TorchSampler(Sampler): for req in state.scheduled_requests.generation_requests: if req.state == LlmRequestState.GENERATION_COMPLETE: continue - new_token = add_token(req, new_tokens, beam=self.BEAM) - stop = self._handle_stop_criteria(req, new_token) processed = 1 - if not stop and get_draft_token_length(req) > 0: - num_accepted = self.process_draft_tokens( - req, new_tokens, new_token) + num_accepted = self.process_draft_tokens(req, new_tokens) + if get_draft_token_length(req) > 0: req.py_num_accepted_draft_tokens = num_accepted req.py_rewind_len = req.py_draft_pages_allocated - num_accepted - processed += num_accepted + processed += num_accepted self.handle_logits(req, state, beam=self.BEAM, count=processed) req.py_decoding_iter += 1 @@ -350,14 +520,14 @@ class TorchSampler(Sampler): """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103""" if any(req.py_return_log_probs for req in requests): return torch.empty( - (self.num_seq_slots, self.MAX_BEAM_WIDTH, self.max_tokens), + (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens), device="cpu", pin_memory=True) return None def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int): if any(req.py_return_generation_logits for req in requests): - return torch.empty((self.max_tokens, self.num_seq_slots, + return torch.empty((self.max_tokens, self.max_num_sequences, self.MAX_BEAM_WIDTH, vocab_size), device="cpu", pin_memory=True) @@ -458,8 +628,8 @@ class TorchSampler(Sampler): no_draft_tokens = len(requests) == sum_steps fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None - seq_slots = torch.as_tensor([r.py_seq_slot for r in requests]) - seq_slots = seq_slots.to(device="cuda", non_blocking=True) + seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests]) + seq_slots = seq_slots_host.to(device="cuda", non_blocking=True) if fast_path: logits = raw_logits[:len(requests)] @@ -480,22 +650,22 @@ class TorchSampler(Sampler): batched_strategy = strategies[0] else: batched_strategy = None - + generator = self.get_generator(raw_logits.device) if batched_strategy is not None: logits = raw_logits[:sum_steps] # Collect steps per request for batched strategy steps_per_request = [ - 1 + len(req.py_draft_tokens) for req in requests + 1 + get_draft_token_length(req) for req in requests ] logits = self._apply_embedding_bias(logits, requests, steps_per_request) batched_next_tokens, batched_softmax = sample( - batched_strategy, logits) + batched_strategy, logits, generator) self.append_eagle3(batched_next_tokens, model_outputs) offset = 0 - for i, (strategy, slot, - steps) in enumerate(zip(strategies, seq_slots, num_steps)): + for i, (strategy, slot, steps, request) in enumerate( + zip(strategies, seq_slots_host, num_steps, requests)): input_slice = slice(offset, offset + steps) logits = raw_logits[input_slice] @@ -503,13 +673,15 @@ class TorchSampler(Sampler): if batched_next_tokens is None: logits = self._apply_embedding_bias(logits, [req]) - next_tokens, softmax = sample(strategy, logits) + next_tokens, softmax = sample(strategy, logits, generator) else: # Batched processing already applied bias, just use the results next_tokens = batched_next_tokens[input_slice] softmax = batched_softmax[input_slice] current_slice = slice(0, steps), slot, beam new_tokens[current_slice] = next_tokens + if request.py_draft_logits is not None: + request.py_target_probs = softmax.clone() if gen_logits_host is not None: gen_logits_host[current_slice].copy_(logits, non_blocking=True) if log_probs_host is not None: @@ -543,7 +715,8 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors): @dataclass(kw_only=True) class SampleStateTRTLLM(SampleState): - finalize_events: dict[str, CudaEvent] + finalize_events: dict[str, CudaEvent] | None = None + """`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`""" host: SampleStateTensorsHostTRTLLM @@ -573,7 +746,9 @@ class TRTLLMSampler(Sampler): self.decoding_config = self.executor_config.decoding_config if self.executor_config.decoding_config else DecodingConfig( decoding_mode) max_attn_window = self.executor_config.kv_cache_config.max_attention_window - self.max_attention_window = max_attn_window if max_attn_window is not None else executor_config.max_seq_len + self.max_attention_window = max( + max_attn_window + ) if max_attn_window is not None else executor_config.max_seq_len self.max_num_sequences = mapping.pp_size * self.executor_config.max_batch_size self.max_seq_idle_microseconds = 180 * 1000 * 1000 self.is_trt_overlap = not disable_overlap_scheduler @@ -607,13 +782,19 @@ class TRTLLMSampler(Sampler): self.MAX_DECODING_TOKENS, buffer_manager) for _ in range(self.num_micro_batches) ], + "sequence_lengths_host": + torch.empty(( + self.max_num_sequences, + self.executor_config.max_beam_width, + ), + dtype=torch.int), "decoder_state": DecoderState(), "decoding_input": [None] * self.num_micro_batches, } self.store["decoder_state"].setup( - max_batch_size=self.executor_config.max_batch_size, + max_num_sequences=self.max_num_sequences, max_beam_width=self.executor_config.max_beam_width, max_attention_window=self.max_attention_window, sink_token_length=0, @@ -629,7 +810,7 @@ class TRTLLMSampler(Sampler): self.algs.decoder = GptDecoderBatched(stream=self.store["torch_stream"]) self.algs.decoder.setup( mode=self.decoding_mode, - max_batch_size=self.executor_config.max_batch_size, + max_num_sequences=self.max_num_sequences, max_beam_width=self.executor_config.max_beam_width, dtype=self.logits_datatype, model_config=self.model_config, @@ -647,11 +828,11 @@ class TRTLLMSampler(Sampler): def setup_sampler_step(self, requests): batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = self.algs.create_new_decoder_requests( self.model_config, self.world_config, self.decoding_config, - requests, self.store["buffer_manager"], self.logits_datatype, + requests.context_requests, self.logits_datatype, self.store["decoder_input_buffers"][self.micro_batch_idx], self.store["decoder_state"], self.store["cuda_stream"], self.algs.decoder.decoder_stream, self.executor_config.max_seq_len, - self.beam_width(requests)) + self.beam_width(requests.context_requests)) local_batch_size = len(batch_slots) if local_batch_size > 0: @@ -662,6 +843,16 @@ class TRTLLMSampler(Sampler): self.model_config.data_type, lookahead_prompt, lookahead_algo_configs) + adp = [ + r for r in requests.generation_requests if r.is_attention_dp_dummy + ] + batch_size = len(adp) + if batch_size == 0: + return + config = make_sampling_config([r.sampling_config for r in adp]) + slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32) + self.algs.decoder.underlying_decoder().setup(config, batch_size, slots) + @staticmethod @torch.inference_mode() def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: @@ -696,7 +887,7 @@ class TRTLLMSampler(Sampler): "Beam search is not supported for multiple prompts and logprobs" ) - self.setup_sampler_step(scheduled_requests.context_requests) + self.setup_sampler_step(scheduled_requests) num_context_logits_prefix_sum = [0] prefix_sum = 0 @@ -947,7 +1138,7 @@ class TRTLLMSampler(Sampler): if finished_sum_host[seq_slot] == beam_width: request.state = LlmRequestState.GENERATION_COMPLETE for request in reqs: - if request.request_id in finalize_events: + if finalize_events is not None and request.request_id in finalize_events: self._post_process_request(request, state) def _finalize_request(self, request: LlmRequest, streaming: bool): diff --git a/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py b/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py index c43c972641..a3f11e5642 100644 --- a/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py @@ -16,6 +16,11 @@ class SeqSlotManager(BaseResourceManager): def prepare_resources(self, scheduled_batch: ScheduledRequests) -> None: for llm_req in scheduled_batch.all_requests(): + if llm_req.is_disagg_generation_init_state: + logger.info( + f"Skip assigning sequence slot for DISAGG_GENERATION_INIT request." + ) + continue if llm_req.seq_slot is None or llm_req.is_disagg_generation_transmission_complete: llm_req.seq_slot = self.slot_manager.add_slot( llm_req.request_id) diff --git a/tensorrt_llm/_torch/speculative/__init__.py b/tensorrt_llm/_torch/speculative/__init__.py index 6918b57390..0856cd46d9 100644 --- a/tensorrt_llm/_torch/speculative/__init__.py +++ b/tensorrt_llm/_torch/speculative/__init__.py @@ -1,3 +1,4 @@ +from .auto_heuristic import suggest_spec_config from .eagle3 import Eagle3SpecMetadata from .interface import SpecMetadata from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker @@ -23,4 +24,5 @@ __all__ = [ "get_spec_resource_manager", "get_spec_worker", "update_spec_config_from_model_config", + "suggest_spec_config", ] diff --git a/tensorrt_llm/_torch/speculative/auto_heuristic.py b/tensorrt_llm/_torch/speculative/auto_heuristic.py new file mode 100644 index 0000000000..907909beb8 --- /dev/null +++ b/tensorrt_llm/_torch/speculative/auto_heuristic.py @@ -0,0 +1,17 @@ +def suggest_spec_config(max_batch_size: int) -> "DecodingBaseConfig": + """ + Suggests a reasonable draft model free speculation scheme. + Used when the user specifies spec_mode == AUTO. + + For now, we always use an ngram scheme that gets disabled at + BS>=32. + """ + from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig + return NGramDecodingConfig( + max_draft_len=5 if max_batch_size <= 4 else 3, + max_matching_ngram_size=3 if max_batch_size <= 4 else 5, + max_concurrency=32, + is_keep_all=True, + is_use_oldest=True, + is_public_pool=True, + ) diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index 9624193d45..82d816b800 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List, Optional, final from ..pyexecutor.llm_request import LlmRequest from ..pyexecutor.resource_manager import ResourceManager @@ -9,6 +9,9 @@ from ..pyexecutor.scheduler import ScheduledRequests class Drafter(ABC): """Abstract base class for all drafter implementations.""" + def __init__(self, max_concurrency: Optional[int] = None) -> None: + self.max_concurrency = max_concurrency + @abstractmethod def prepare_draft_tokens( self, @@ -23,6 +26,13 @@ class Drafter(ABC): """ raise NotImplementedError + @final def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool: - """Check if spec decode should be used for the current iteration.""" + """ + You probably don't want to override this. ModelEngine + assumes that speculation is always on if max_concurrency + is not specified by the user's spec config. + """ + if self.max_concurrency is not None: + return len(requests) <= self.max_concurrency return True diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 72b598fc61..417becf12f 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -266,7 +266,8 @@ class Eagle3OneModelWorker(nn.Module): self.max_draft_len = self.spec_config.max_draft_len self.mapping = mapping - @torch.compile(options={"max-autotune": True}) + # Skip torch.compile for now since current Torch is not compatible with Triton 3.4 + # @torch.compile(options={"max-autotune": True}) def forward(self, input_ids, position_ids, hidden_states, logits, attn_metadata, spec_metadata, draft_model): batch_size = attn_metadata.num_seqs diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index d606073f00..f7cdd92a56 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -23,6 +23,9 @@ class SpeculativeDecodingMode(IntEnum): def is_mtp(self): return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE + def is_mtp_vanilla(self): + return self == SpeculativeDecodingMode.MTP + def is_mtp_eagle(self): return self == SpeculativeDecodingMode.MTP_EAGLE @@ -88,11 +91,13 @@ class SpeculativeDecodingMode(IntEnum): any spec dec mode that uses the SpecExecutor. """ - # Fixme: only trtllm attention backend supports eagle3 generation-phase kernels on blackwell. - return ((self.is_eagle3() or self.is_draft_target()) - and not (isinstance(attention_backend, TrtllmAttention) - and get_sm_version() == 100) - ) or self.is_ngram() or self.is_user_provided() + if self.use_one_engine(): + # 1-model has separate logic for handling draft tokens + return False + + # The special XQA generation kernels only exist with the TRTLLM backend on blackwell. + return not issubclass(attention_backend, + TrtllmAttention) or get_sm_version() != 100 def attention_need_spec_dec_mode(self): """ diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 46d53bee31..7f11142c3f 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -8,10 +8,11 @@ import torch from tensorrt_llm._utils import nvtx_range from tensorrt_llm.logger import logger +from ..pyexecutor.guided_decoder import GuidedDecoder from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState, - SamplingConfig, get_draft_token_length) + get_draft_token_length) from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager -from ..pyexecutor.sampler import Sampler, SampleState +from ..pyexecutor.sampler import Sampler, SampleState, TorchSampler from ..pyexecutor.scheduler import ScheduledRequests from ..pyexecutor.seq_slot_manager import SeqSlotManager from .drafter import Drafter @@ -45,7 +46,10 @@ class ModelDrafter(Drafter): draft_seq_slot_manager: SeqSlotManager, sampler: Sampler, spec_resource_manager: Optional[BaseResourceManager] = None, + guided_decoder: Optional[GuidedDecoder] = None, ): + super().__init__(spec_config.max_concurrency) + # Validate required parameters if draft_model_engine is None: raise ValueError("draft_model_engine cannot be None") @@ -62,19 +66,24 @@ class ModelDrafter(Drafter): self.max_draft_tokens = max_draft_tokens # Sampling self.sampler = sampler + self._request_draft_logits = False + if isinstance(sampler, TorchSampler): + self._request_draft_logits = sampler.enable_mixed_sampler + self.guided_decoder = guided_decoder - def _create_draft_request(self, request_id: int, max_new_tokens: int, - input_tokens: Optional[List], - sampling_config: SamplingConfig, - return_perf_metrics: bool) -> LlmRequest: + def _create_draft_request(self, request: LlmRequest, + input_tokens: Optional[List]) -> LlmRequest: """Create a draft request with common parameters.""" - return LlmRequest(request_id=request_id, - max_new_tokens=max_new_tokens, - input_tokens=input_tokens, - sampling_config=sampling_config, - return_perf_metrics=return_perf_metrics, + return LlmRequest(input_tokens=input_tokens, + request_id=request.py_request_id, + max_new_tokens=request.py_max_new_tokens, + sampling_config=request.sampling_config, + guided_decoding_params=request.guided_decoding_params, + target_seq_slot=request.py_seq_slot, + return_perf_metrics=request.return_perf_metrics, is_streaming=False, - is_draft=True) + is_draft=True, + return_generation_logits=self._request_draft_logits) def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]: """Initialize draft token tracking for a request.""" @@ -92,11 +101,7 @@ class ModelDrafter(Drafter): def _create_context_request(self, request: LlmRequest, input_tokens: Any) -> LlmRequest: """Create a context request for first-time drafting.""" - new_request = self._create_draft_request(request.py_request_id, - request.py_max_new_tokens, - input_tokens, - request.sampling_config, - request.return_perf_metrics) + new_request = self._create_draft_request(request, input_tokens) begin_compute, end_compute = request.py_last_context_chunk if begin_compute is not None: @@ -107,13 +112,7 @@ class ModelDrafter(Drafter): def _create_generation_request(self, request: LlmRequest, input_tokens: Any) -> LlmRequest: """Create a generation request when no tokens were accepted.""" - new_request = self._create_draft_request(request.py_request_id, - request.py_max_new_tokens, - input_tokens[:-1], - request.sampling_config, - request.return_perf_metrics) - # Explicitly add the last token so get_last_tokens() returns the right value - new_request.add_new_token(input_tokens[-1], 0) + new_request = self._create_draft_request(request, input_tokens) new_request.state = LlmRequestState.GENERATION_IN_PROGRESS return new_request @@ -124,11 +123,7 @@ class ModelDrafter(Drafter): Create a chunked context request for accepted tokens. Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3) """ - new_request = self._create_draft_request(request.py_request_id, - request.py_max_new_tokens, - input_tokens, - request.sampling_config, - request.return_perf_metrics) + new_request = self._create_draft_request(request, input_tokens) new_request.context_chunk_size = num_accepted_tokens + 1 new_request.context_current_position = len( input_tokens) - num_accepted_tokens - 1 @@ -140,7 +135,7 @@ class ModelDrafter(Drafter): num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens( request) input_tokens = get_draft_model_prompt(self.spec_config.spec_dec_mode, - request.get_tokens()[0]) + request.get_tokens(0)) # First time seeing this request - context request if request.max_beam_num_tokens - 1 == request.py_prompt_len: @@ -202,7 +197,7 @@ class ModelDrafter(Drafter): # We hit this path if we're doing chunked prefill. The target model processed # a prefill chunk on the last iteration. Now, we need to fill in the KV cache # for the draft model too. - all_tokens = request.get_tokens()[0] + all_tokens = request.get_tokens(0) input_tokens = get_draft_model_prompt( self.spec_config.spec_dec_mode, all_tokens) @@ -305,6 +300,8 @@ class ModelDrafter(Drafter): continue target_model_req.py_draft_tokens.append(req.get_last_tokens(0)) + if self._request_draft_logits: + target_model_req.py_draft_logits = req.py_result.generation_logits if req.state != LlmRequestState.GENERATION_COMPLETE and len( target_model_req.py_draft_tokens ) < target_model_req.py_draft_pages_allocated: @@ -323,6 +320,14 @@ class ModelDrafter(Drafter): req.py_draft_tokens.extend( 0 for _ in range(max_draft_tokens - num_draft_tokens)) + def _execute_guided_decoder(self, + scheduled_batch: ScheduledRequests, + logits: torch.Tensor, + d2t: Optional[torch.Tensor] = None): + if self.guided_decoder is not None: + self.guided_decoder.build(scheduled_batch) + self.guided_decoder.execute(scheduled_batch, logits, d2t=d2t) + @nvtx_range("prepare_draft_tokens") def prepare_draft_tokens( self, @@ -357,6 +362,9 @@ class ModelDrafter(Drafter): # Initial forward pass outputs = self._forward_draft_model(draft_batch, resource_manager) + self._execute_guided_decoder(draft_batch, + outputs['logits'], + d2t=outputs.get('d2t')) sample_state = self._sample_async(draft_batch, outputs) previous_batch = sample_state @@ -374,10 +382,14 @@ class ModelDrafter(Drafter): outputs = self._forward_draft_model(draft_batch, resource_manager, previous_batch) + if previous_batch is not None: + self._update_requests(previous_batch) + self._execute_guided_decoder(draft_batch, + outputs['logits'], + d2t=outputs.get('d2t')) sample_state = self._sample_async(draft_batch, outputs) self._update_request_states(draft_batch) if previous_batch is not None: - self._update_requests(previous_batch) new_requests = self._process_decoded_tokens( previous_batch.scheduled_requests, req_id_to_old_request) @@ -393,6 +405,9 @@ class ModelDrafter(Drafter): req_id_to_old_request) self._pad_to_max_draft_tokens(scheduled_requests) + if self.guided_decoder is not None: + self.guided_decoder.rollback_draft_tokens(scheduled_requests) + except Exception as e: traceback.print_exc() error_msg = str(e) diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 1772125bcb..2658ce539b 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -330,11 +330,9 @@ class MTPWorker(nn.Module): position_ids, hidden_states, logits, - lm_head, - embed_tokens, attn_metadata, spec_metadata, - mtp_layers, + draft_model, ): ''' Example: @@ -470,9 +468,10 @@ class MTPWorker(nn.Module): next_draft_tokens = [] last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 - for _, mtp_layer in enumerate(mtp_layers): - hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs) - logits = mtp_layer.shared_head(hidden_states, lm_head, + for _, mtp_layer in enumerate(draft_model.mtp_layers): + hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens, + **draft_inputs) + logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head, attn_metadata).float() new_draft_token = self.draft_sampler(logits) next_draft_tokens.append(new_draft_token) @@ -517,11 +516,9 @@ class MTPWorker(nn.Module): position_ids, hidden_states, logits, - lm_head, - embed_tokens, attn_metadata, spec_metadata, - mtp_layers, + draft_model, ): batch_size = attn_metadata.num_seqs mtp_num_modules = self.spec_config.num_nextn_predict_layers @@ -1127,11 +1124,9 @@ class MTPEagleWorker(MTPWorker): position_ids, hidden_states, logits, - lm_head, - embed_tokens, attn_metadata, spec_metadata, - mtp_layers, + draft_model, ): batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts @@ -1172,8 +1167,8 @@ class MTPEagleWorker(MTPWorker): next_draft_tokens = [] for i in range(self.mtp_num_modules): if i == 0: - hidden_states = mtp_layers[0]( - embed_tokens=embed_tokens, + hidden_states = draft_model.mtp_layers[0]( + embed_tokens=draft_model.embed_tokens, all_rank_num_tokens=spec_metadata.all_rank_num_tokens, all_rank_max_num_tokens=spec_metadata. all_rank_max_num_tokens, @@ -1186,8 +1181,8 @@ class MTPEagleWorker(MTPWorker): gather_ids = torch.concat( [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0) else: - hidden_states = mtp_layers[0]( - embed_tokens=embed_tokens, + hidden_states = draft_model.mtp_layers[0]( + embed_tokens=draft_model.embed_tokens, all_rank_num_tokens=spec_metadata. subseq_all_rank_num_tokens, all_rank_max_num_tokens=max( @@ -1197,8 +1192,9 @@ class MTPEagleWorker(MTPWorker): **inputs) # All of the seq_len are 1, use batch_indices_cuda as gather_ids gather_ids = spec_metadata.batch_indices_cuda[:batch_size] - logits = mtp_layers[0].shared_head(hidden_states[gather_ids], - lm_head, attn_metadata, True) + logits = draft_model.mtp_layers[0].shared_head( + hidden_states[gather_ids], draft_model.lm_head, attn_metadata, + True) new_draft_token = self.draft_sampler(logits) hidden_states, position_ids = self.update_draft_tokens( diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index 39267f5da2..1538897605 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -1,11 +1,12 @@ from itertools import chain +from typing import Optional from ordered_set import OrderedSet from tensorrt_llm.llmapi import NGramDecodingConfig from tensorrt_llm.logger import logger -from ..pyexecutor.llm_request import * +from ..pyexecutor.llm_request import LlmRequest, LlmRequestState from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager from ..pyexecutor.scheduler import ScheduledRequests from .drafter import Drafter @@ -86,13 +87,13 @@ class NGramPoolManager(BaseResourceManager): self, prefix: list[int], request_id: int, - end_id: int, + padding_id: int, max_sequence_length: int, ): prefix_len = len(prefix) max_draft_token_length_this_step = max_sequence_length - 1 - prefix_len if max_draft_token_length_this_step <= 0: # No draft token is need if the prefix is long enough - return [end_id] + return [padding_id] if request_id not in self.start_index: # Extend start_index and pool for a new request self.start_index[request_id] = 0 if not self.is_public_pool: @@ -125,7 +126,7 @@ class NGramPoolManager(BaseResourceManager): pool[pattern].add(new_match) # Find match - draft_tokens = [end_id] # fallback value + draft_tokens = [padding_id] # fallback value for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0, -1): pattern = tuple(prefix[-size:]) @@ -167,6 +168,7 @@ class NGramDrafter(Drafter): spec_config: NGramDecodingConfig, ngram_pool_manager: NGramPoolManager = None, ): + super().__init__(spec_config.max_concurrency) assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool." self.spec_config = spec_config self.max_draft_len = spec_config.max_draft_len @@ -177,10 +179,6 @@ class NGramDrafter(Drafter): scheduled_requests: ScheduledRequests, resource_manager: Optional[ResourceManager] = None, ) -> None: - # Disable NGram speculative decoding auto heuristic for batch size > 32. - if self.spec_config.is_auto_heuristic and len( - scheduled_requests.all_requests()) > 32: - return # Sort by request_id when py_batch_idx is None as a fallback. # This happens in the disagg case: for a set of new requests, we draft # before forward_step, so py_batch_idx is not assigned. @@ -190,17 +188,18 @@ class NGramDrafter(Drafter): (r.py_batch_idx is None, r.py_batch_idx or r.request_id), ): # Add new token to a copy of the generated tokens to find new draft tokens - prefix = list(request.get_tokens()[0]) # Get a copy + prefix = list(request.get_tokens(0)) # Get a copy # Generate draft tokens draft_tokens = self.spec_resource_manager.get_draft_tokens( prefix, request.request_id, - request.py_end_id, - request.py_orig_prompt_len + request.py_max_new_tokens, + padding_id=0, + max_sequence_length=request.py_orig_prompt_len + + request.py_max_new_tokens, ) # Pad length to `self.max_draft_len` if len(draft_tokens) > 0: pad_length = self.max_draft_len - len(draft_tokens) - draft_tokens.extend([request.py_end_id] * pad_length) + draft_tokens.extend([0] * pad_length) request.py_draft_tokens = draft_tokens diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index e8db9d1f56..c4a4ccf7e3 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -1,7 +1,9 @@ -from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler -from tensorrt_llm._torch.speculative.interface import SpecMetadata +from typing import Optional +from ..pyexecutor.guided_decoder import GuidedDecoder +from ..pyexecutor.sampler import TorchSampler from ..pyexecutor.seq_slot_manager import SeqSlotManager +from ..speculative.interface import SpecMetadata from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata, Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3SpecMetadata) @@ -114,8 +116,11 @@ def get_spec_decoder(sampler_args: TorchSampler.Args, f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}") -def get_spec_drafter(model_engine, draft_model_engine, sampler, - spec_resource_manager): +def get_spec_drafter(model_engine, + draft_model_engine, + sampler, + spec_resource_manager, + guided_decoder: Optional[GuidedDecoder] = None): spec_config = model_engine.spec_config if spec_config is None: return None @@ -126,10 +131,13 @@ def get_spec_drafter(model_engine, draft_model_engine, sampler, max_num_requests = model_engine.batch_size if spec_config.spec_dec_mode.is_draft_target( ) or spec_config.spec_dec_mode.is_eagle3(): - return ModelDrafter(spec_config, draft_model_engine, + return ModelDrafter(spec_config, + draft_model_engine, spec_config.max_draft_len, - SeqSlotManager(max_num_requests), sampler, - spec_resource_manager) + SeqSlotManager(max_num_requests), + sampler, + spec_resource_manager=spec_resource_manager, + guided_decoder=guided_decoder) if spec_config.spec_dec_mode.is_ngram(): return NGramDrafter(spec_config, spec_resource_manager) @@ -146,11 +154,12 @@ def get_num_spec_layers(spec_config): def get_spec_worker(spec_config, model_config, mapping): - if spec_config.spec_dec_mode.is_mtp(): + spec_dec_mode = spec_config.spec_dec_mode + if spec_dec_mode.is_mtp_vanilla(): return MTPWorker(spec_config, model_config) - if spec_config.spec_dec_mode.is_mtp_eagle(): + if spec_dec_mode.is_mtp_eagle(): return MTPEagleWorker(spec_config, model_config) - if spec_config.spec_dec_mode.is_eagle3_one_model(): + if spec_dec_mode.is_eagle3_one_model(): return Eagle3OneModelWorker(spec_config, mapping) return None diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 15f8e634a5..4068ad44a6 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -12,7 +12,12 @@ from tensorrt_llm.quantization.utils import fp4_utils is_torch_compiling_flag = False -aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap'] +aux_stream_name_list = [ + 'Attention', + 'MoeShared', + 'MoeChunkingOverlap', + 'MoeBalancer', +] AuxStreamType = Enum( 'AuxStreamType', aux_stream_name_list, @@ -120,7 +125,7 @@ def swizzle_sf(sf: torch.Tensor, """ sf_cols = ceil_div(cols, scaling_vector_size) sf = sf.view(-1, rows, sf_cols) - return torch.ops.trtllm.nvfp4_block_scale_interleave(sf) + return torch.ops.trtllm.block_scale_interleave(sf) def unswizzle_sf(sf: torch.Tensor, @@ -138,8 +143,7 @@ def unswizzle_sf(sf: torch.Tensor, """ sf_cols = ceil_div(cols, scaling_vector_size) sf = sf.view(-1, rows, sf_cols) - return torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(sf).view( - -1, sf_cols) + return torch.ops.trtllm.block_scale_interleave_reverse(sf).view(-1, sf_cols) @torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=()) @@ -261,3 +265,13 @@ def set_piecewise_cuda_graph_flag(enable: bool): def get_piecewise_cuda_graph_flag() -> bool: global _enable_piecewise_cuda_graph return _enable_piecewise_cuda_graph + + +@contextlib.contextmanager +def piecewise_cuda_graph(enable: bool): + prev_enable = get_piecewise_cuda_graph_flag() + set_piecewise_cuda_graph_flag(enable) + try: + yield + finally: + set_piecewise_cuda_graph_flag(prev_enable) diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 0098983c6f..41c3d28a5b 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -20,6 +20,7 @@ import linecache import math import os import struct +import tempfile import trace import weakref from contextlib import contextmanager @@ -180,6 +181,7 @@ _str_to_binding_dtype_dict = dict( bool=DataType.BOOL, fp8=DataType.FP8, ) +_binding_to_str_dtype = {v: k for k, v in _str_to_binding_dtype_dict.items()} _binding_dtype_size = { DataType.INT64: 8, @@ -194,6 +196,12 @@ _binding_dtype_size = { } +def binding_to_str_dtype(binding_dtype) -> str: + ret = _binding_to_str_dtype.get(binding_dtype) + assert ret is not None, f'Unsupported binding dtype: {binding_dtype}' + return ret + + def binding_dtype_size(dtype: DataType): return _binding_dtype_size[dtype] @@ -1022,11 +1030,15 @@ class KVCacheEventSerializer: if event_serialize_func is None: raise ValueError(f"Unknown KVCache event data type: {event_type}") - return { + json_str = { "event_id": event.event_id, "data": event_serialize_func(event.data), - "window_size": event.window_size + "window_size": event.window_size, } + if event.attention_dp_rank is not None: + json_str["attention_dp_rank"] = event.attention_dp_rank + + return json_str @staticmethod def _created_to_json(data): @@ -1109,3 +1121,17 @@ def is_multi_device_enable(): the number of devices """ return local_mpi_size() > 1 + + +def set_prometheus_multiproc_dir() -> object: + # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266 + global prometheus_multiproc_dir + if "PROMETHEUS_MULTIPROC_DIR" in os.environ: + logger.info("User set PROMETHEUS_MULTIPROC_DIR detected.") + prometheus_multiproc_dir = tempfile.TemporaryDirectory( + dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]) + else: + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + logger.info( + f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}") diff --git a/tensorrt_llm/bench/benchmark/low_latency.py b/tensorrt_llm/bench/benchmark/low_latency.py index af86fb2b1e..ad200af9c6 100644 --- a/tensorrt_llm/bench/benchmark/low_latency.py +++ b/tensorrt_llm/bench/benchmark/low_latency.py @@ -25,7 +25,7 @@ from tensorrt_llm.llmapi import CapacitySchedulerPolicy from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode # isort: off -from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS +from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, update_sampler_args_with_extra_options, ALL_SUPPORTED_BACKENDS # isort: on from tensorrt_llm.bench.utils.data import (create_dataset_from_stream, initialize_tokenizer, @@ -56,6 +56,12 @@ from tensorrt_llm.sampling_params import SamplingParams default=.90, help="The percentage of memory to use for KV Cache after model load.", ) +@optgroup.option( + "--mamba_ssm_cache_dtype", + type=click.Choice(["auto", "float16", "bfloat16", "float32"]), + default="auto", + help="Data type for Mamba SSM cache. If 'auto', inferred from model config.", +) @optgroup.option( "--max_seq_len", type=int, @@ -129,6 +135,13 @@ from tensorrt_llm.sampling_params import SamplingParams default=1, help="Number of search beams.", ) +@optgroup.option("--sampler_options", + type=click.Path(exists=True, + readable=True, + path_type=Path, + resolve_path=True), + default=None, + help="Path to a YAML file that sets sampler options.") @optgroup.option( "--concurrency", type=int, @@ -320,12 +333,16 @@ def latency_command( eos_id = tokenizer.eos_token_id if not ignore_eos else -1 pad_id = tokenizer.pad_token_id if not ignore_eos else -1 - sampling_params = SamplingParams( - end_id=eos_id, - pad_id=pad_id, - n=beam_width, - use_beam_search=beam_width > 1, - ) + sampler_args = { + "end_id": eos_id, + "pad_id": pad_id, + "n": beam_width, + "use_beam_search": beam_width > 1 + } + sampler_args = update_sampler_args_with_extra_options( + sampler_args, params.pop("sampler_options")) + sampling_params = SamplingParams(**sampler_args) + post_proc_params = None # No detokenization # Perform warmup if requested. diff --git a/tensorrt_llm/bench/benchmark/throughput.py b/tensorrt_llm/bench/benchmark/throughput.py index 8b83c85d51..57c86ac0f3 100755 --- a/tensorrt_llm/bench/benchmark/throughput.py +++ b/tensorrt_llm/bench/benchmark/throughput.py @@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter from tensorrt_llm.bench.build.build import get_model_config +from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir # isort: off from tensorrt_llm.bench.benchmark.utils.general import ( @@ -21,7 +22,8 @@ from tensorrt_llm.bench.benchmark.utils.general import ( from tensorrt_llm import LLM as PyTorchLLM from tensorrt_llm._tensorrt_engine import LLM from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM -from tensorrt_llm.bench.benchmark.utils.general import generate_warmup_dataset +from tensorrt_llm.bench.benchmark.utils.general import ( + generate_warmup_dataset, update_sampler_args_with_extra_options) from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment from tensorrt_llm.bench.dataclasses.reporting import ReportUtility @@ -49,6 +51,16 @@ from tensorrt_llm.sampling_params import SamplingParams type=click.Choice(ALL_SUPPORTED_BACKENDS), default="pytorch", help="The backend to use when running benchmarking.") +@optgroup.option( + "--custom_module_dirs", + type=click.Path(exists=True, + readable=True, + path_type=Path, + resolve_path=True), + default=None, + multiple=True, + help="Paths to custom module directories to import.", +) @optgroup.option( "--extra_llm_api_options", type=str, @@ -56,6 +68,13 @@ from tensorrt_llm.sampling_params import SamplingParams help= "Path to a YAML file that overwrites the parameters specified by trtllm-bench." ) +@optgroup.option("--sampler_options", + type=click.Path(exists=True, + readable=True, + path_type=Path, + resolve_path=True), + default=None, + help="Path to a YAML file that sets sampler options.") @optgroup.option( "--max_batch_size", type=int, @@ -84,6 +103,12 @@ from tensorrt_llm.sampling_params import SamplingParams default=.90, help="The percentage of memory to use for KV Cache after model load.", ) +@optgroup.option( + "--mamba_ssm_cache_dtype", + type=click.Choice(["auto", "float16", "bfloat16", "float32"]), + default="auto", + help="Data type for Mamba SSM cache. If 'auto', inferred from model config.", +) @optgroup.group( "Engine Input Configuration", help="Input configuration for driving the engine.", @@ -98,6 +123,16 @@ from tensorrt_llm.sampling_params import SamplingParams required=False, help="Pass in a dataset file for parsing instead of stdin.", ) +# For text models, tokenizer initialization is not needed when loading the model since the dataset is already tokenized. +# For this reason, we skip tokenizer initialization by default. +# However, for VLM models, tokenizer initialization is needed inside the model since the dataset contains texts and +# raw media data. We cannot skip tokenizer initialization in this case. +@optgroup.option( + "--no_skip_tokenizer_init", + is_flag=True, + default=False, + help="Do not skip tokenizer initialization when loading the model.", +) @optgroup.option( "--eos_id", type=int, @@ -112,6 +147,18 @@ from tensorrt_llm.sampling_params import SamplingParams default=None, help="Modality of the multimodal requests.", ) +@optgroup.option( + "--image_data_format", + type=click.Choice(["pt", "pil"]), + default="pt", + help="Format of the image data for multimodal models.", +) +@optgroup.option( + "--data_device", + type=click.Choice(["cuda", "cpu"]), + default="cuda", + help="Device to load the multimodal data on.", +) @optgroup.option( "--max_input_len", type=int, @@ -256,7 +303,17 @@ def throughput_command( logger.info("Preparing to run throughput benchmark...") # Parameters from CLI # Model, experiment, and engine params + custom_module_dirs: list[Path] = params.pop("custom_module_dirs", []) + for custom_module_dir in custom_module_dirs: + try: + import_custom_module_from_dir(custom_module_dir) + except Exception as e: + logger.error( + f"Failed to import custom module from {custom_module_dir}: {e}") + raise e + dataset_path: Path = params.get("dataset") + no_skip_tokenizer_init: bool = params.get("no_skip_tokenizer_init", False) eos_id: int = params.get("eos_id") warmup: int = params.get("warmup") num_requests: int = params.get("num_requests") @@ -268,6 +325,8 @@ def throughput_command( backend: str = params.get("backend") modality: str = params.get("modality") max_input_len: int = params.get("max_input_len") + image_data_format: str = params.get("image_data_format", "pt") + data_device: str = params.get("data_device", "cpu") model_type = get_model_config(model, checkpoint_path).model_type # Reporting options @@ -280,7 +339,7 @@ def throughput_command( # Runtime kwargs and option tracking. kwargs = {} - # Initialize the HF tokenizer for the specified model. + # Initialize the HF tokenizer for the specified model. This is only used for data preparation. tokenizer = initialize_tokenizer(checkpoint_path) # Dataset Loading and Preparation @@ -292,6 +351,8 @@ def throughput_command( model_dir=checkpoint_path, model_type=model_type, modality=modality, + image_data_format=image_data_format, + data_device=data_device, max_input_seq_len_for_multimodal=max_input_len) metadata.dataset_path = dataset_path params["target_input_len"] = params.get( @@ -386,6 +447,7 @@ def throughput_command( logger.info("Setting up throughput benchmark.") kwargs = kwargs | runtime_config.get_llm_args() kwargs['backend'] = backend + kwargs['skip_tokenizer_init'] = not no_skip_tokenizer_init if backend == "pytorch" and iteration_log is not None: kwargs["enable_iter_perf_stats"] = True @@ -401,10 +463,16 @@ def throughput_command( else: llm = LLM(**kwargs) - sampling_params = SamplingParams(end_id=eos_id, - pad_id=eos_id, - n=beam_width, - use_beam_search=beam_width > 1) + sampler_args = { + "end_id": eos_id, + "pad_id": eos_id, + "n": beam_width, + "use_beam_search": beam_width > 1 + } + sampler_args = update_sampler_args_with_extra_options( + sampler_args, params.pop("sampler_options")) + sampling_params = SamplingParams(**sampler_args) + post_proc_params = None # No detokenization # Perform warmup if requested. @@ -461,10 +529,10 @@ def throughput_command( report_utility.report_statistics() except KeyboardInterrupt: logger.info("Keyboard interrupt, exiting benchmark...") - sys.exit(130) - except Exception as e: - logger.error(f"Error during benchmarking: {e}") - sys.exit(-1) + except Exception: + import traceback + logger.error(f"Error during benchmarking:\n{traceback.format_exc()}") + sys.exit(1) finally: if llm is not None: llm.shutdown() diff --git a/tensorrt_llm/bench/benchmark/utils/general.py b/tensorrt_llm/bench/benchmark/utils/general.py index bc72b5e146..ff3cd933ce 100755 --- a/tensorrt_llm/bench/benchmark/utils/general.py +++ b/tensorrt_llm/bench/benchmark/utils/general.py @@ -12,6 +12,7 @@ from tensorrt_llm._torch.pyexecutor.model_engine import \ validate_and_set_kv_cache_quant from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings, get_model_config) +from tensorrt_llm.bench.build.dataclasses import NemotronHybridConfig from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata, InferenceRequest) from tensorrt_llm.logger import logger @@ -88,6 +89,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str, enable_chunked_prefill = params.get("enable_chunked_prefill", False) kv_cache_dtype = "auto" + mamba_ssm_cache_dtype = params.get("mamba_ssm_cache_dtype", "auto") kv_cache_config = {} if extra_llm_api_options: with open(extra_llm_api_options, 'r') as f: @@ -96,6 +98,8 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str, "dtype": "auto", }) kv_cache_dtype = kv_cache_config.get("dtype", "auto") + mamba_ssm_cache_dtype = kv_cache_config.get("mamba_ssm_cache_dtype", + mamba_ssm_cache_dtype) enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill", enable_chunked_prefill) @@ -115,6 +119,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str, else: model_config = get_model_config(model, model_path) + if isinstance(model_config, NemotronHybridConfig): + model_config.set_mamba_ssm_cache_dtype(mamba_ssm_cache_dtype) + from tensorrt_llm._torch.model_config import ModelConfig model = model_path or model tllm_model_config = ModelConfig.from_pretrained(model, @@ -161,6 +168,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str, } kv_cache_config["dtype"] = kv_cache_dtype + kv_cache_config["mamba_ssm_cache_dtype"] = mamba_ssm_cache_dtype pyt_options = { "cuda_graph_config": cuda_graph_config, @@ -191,3 +199,39 @@ def generate_warmup_dataset(requests, steps) -> List[InferenceRequest]: warm_up_dataset = choices(requests, k=steps) shuffle(warm_up_dataset) return warm_up_dataset + + +def update_sampler_args_with_extra_options(sampler_args: Dict, + sampler_options: str) -> Dict: + """Update sampler arguments with options from a YAML file. + + Args: + sampler_args: Base sampler arguments dictionary. + sampler_options: Path to YAML file containing additional options. + + Returns: + Dict: Merged sampler arguments. + + Raises: + FileNotFoundError: If the YAML file doesn't exist. + yaml.YAMLError: If the YAML file is malformed. + TypeError: If the YAML content is not a dictionary. + """ + if sampler_options is not None: + try: + with open(sampler_options, 'r') as f: + sampler_options_dict = yaml.safe_load(f) + except FileNotFoundError: + raise FileNotFoundError( + f"Sampler options file not found: {sampler_options}") + except yaml.YAMLError as e: + raise yaml.YAMLError( + f"Invalid YAML in sampler options file {sampler_options}: {e}") + + if not isinstance(sampler_options_dict, dict): + raise TypeError( + f"Sampler options file {sampler_options} must contain a dictionary, " + f"got {type(sampler_options_dict)}") + + sampler_args = sampler_args | sampler_options_dict + return sampler_args diff --git a/tensorrt_llm/bench/build/dataclasses.py b/tensorrt_llm/bench/build/dataclasses.py index 93377a5779..9df0c915ff 100755 --- a/tensorrt_llm/bench/build/dataclasses.py +++ b/tensorrt_llm/bench/build/dataclasses.py @@ -223,6 +223,7 @@ class NemotronHybridConfig(ModelConfig): mamba_head_dim: int d_inner: Optional[int] = Field(default=None) num_mamba_layers: Optional[int] = Field(default=None) + mamba_ssm_cache_dtype: Optional[str] = Field(default="auto") @model_validator(mode="after") def set_values_if_none(self): @@ -248,3 +249,6 @@ class NemotronHybridConfig(ModelConfig): def cache_memory_fraction(self, cache_memory_fraction): # Each mamba cache entry is pretty large (~50MB for 8B model), so we are more conservative when estimating the max batch size return cache_memory_fraction**2 + + def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str): + self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index 93904ff3e2..5815e25af1 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -1,5 +1,8 @@ from typing import Tuple +import torch + +from tensorrt_llm._utils import str_dtype_to_torch from tensorrt_llm.llmapi.llm_utils import QuantConfig from tensorrt_llm.logger import logger from tensorrt_llm.quantization.mode import QuantAlgo @@ -77,8 +80,16 @@ def calc_engine_setting( target_seq_len = target_input_len + target_output_len cache_memory = available_memory * model_config.cache_memory_fraction( kv_cache_gpu_mem_fraction) + + bytes_per_elem = BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT) + if isinstance(model_config, NemotronHybridConfig): + mamba_ssm_cache_dtype = model_config.mamba_ssm_cache_dtype + if mamba_ssm_cache_dtype != "auto": + if str_dtype_to_torch(mamba_ssm_cache_dtype) == torch.float32: + bytes_per_elem = 4.0 + gb_per_extra_cache = model_config.extra_model_cache_in_gb( - BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT), target_seq_len) + bytes_per_elem, target_seq_len) kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len + gb_per_extra_cache) extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests diff --git a/tensorrt_llm/bench/dataclasses/reporting.py b/tensorrt_llm/bench/dataclasses/reporting.py index a4154ee43c..acf7f60bcb 100755 --- a/tensorrt_llm/bench/dataclasses/reporting.py +++ b/tensorrt_llm/bench/dataclasses/reporting.py @@ -318,9 +318,9 @@ class ReportUtility: "backend": "Pytorch", "dtype": - torch_dtype_to_str( - model_config.pretrained_config.torch_dtype - or model_config.pretrained_config.text_config.torch_dtype), + torch_dtype_to_str(model_config.torch_dtype + or model_config.pretrained_config. + get_text_config().torch_dtype), "kv_cache_dtype": model_config.quant_config.kv_cache_quant_algo, "quantization": diff --git a/tensorrt_llm/bench/utils/data.py b/tensorrt_llm/bench/utils/data.py index 6469d08739..080655c727 100644 --- a/tensorrt_llm/bench/utils/data.py +++ b/tensorrt_llm/bench/utils/data.py @@ -41,6 +41,8 @@ def create_dataset_from_stream( model_dir: str = None, model_type: str = None, modality: str = None, + image_data_format: str = "pt", + data_device: str = "cpu", max_input_seq_len_for_multimodal: int = 4096, ) -> Tuple[DatasetMetadata, List[InferenceRequest]]: """Generate metadata and a list of requests to drive benchmarking. @@ -130,7 +132,9 @@ def create_dataset_from_stream( model_type=model_type, modality=modality, prompts=prompts, - media=media_paths) # list of dicts + media=media_paths, # list of dicts + image_data_format=image_data_format, + device=data_device) all_isl = [] all_seq_len = [] diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 11d528a853..32a66b160d 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -36,7 +36,7 @@ from .bindings import KVCacheType from .functional import PositionEmbeddingType from .graph_rewriting import optimize from .logger import logger -from .lora_manager import LoraConfig +from .lora_helper import LoraConfig from .models import PretrainedConfig, PretrainedModel from .models.modeling_utils import SpeculativeDecodingMode, optimize_model from .network import Network, net_guard diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index 9374883a9c..7cca71c879 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -31,7 +31,8 @@ from tensorrt_llm.auto_parallel.cluster_info import cluster_infos from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.builder import BuildConfig, Engine, build from tensorrt_llm.logger import logger, severity_map -from tensorrt_llm.lora_manager import LoraConfig, LoraManager +from tensorrt_llm.lora_helper import LoraConfig +from tensorrt_llm.lora_manager import LoraManager from tensorrt_llm.models import MODEL_MAP, PretrainedConfig from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode from tensorrt_llm.plugin import PluginConfig, add_plugin_argument diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 4f26be6579..f949fda0d9 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -81,6 +81,7 @@ def get_llm_args(model: str, moe_expert_parallel_size: Optional[int] = None, gpus_per_node: Optional[int] = None, free_gpu_memory_fraction: Optional[float] = None, + mamba_ssm_cache_dtype: str = "auto", num_postprocess_workers: int = 0, trust_remote_code: bool = False, reasoning_parser: Optional[str] = None, @@ -96,7 +97,8 @@ def get_llm_args(model: str, max_beam_width=max_beam_width, max_seq_len=max_seq_len) kv_cache_config = KvCacheConfig( - free_gpu_memory_fraction=free_gpu_memory_fraction) + free_gpu_memory_fraction=free_gpu_memory_fraction, + mamba_ssm_cache_dtype=mamba_ssm_cache_dtype) dynamic_batch_config = DynamicBatchConfig( enable_batch_size_tuning=True, @@ -237,6 +239,12 @@ def launch_server(host: str, default=0.9, help="Free GPU memory fraction reserved for KV Cache, " "after allocating model weights and buffers.") +@click.option( + "--mamba_ssm_cache_dtype", + type=click.Choice(["auto", "float16", "bfloat16", "float32"]), + default="auto", + help="Data type for Mamba SSM cache. If 'auto', inferred from model config." +) @click.option( "--num_postprocess_workers", type=int, @@ -277,16 +285,17 @@ def launch_server(host: str, help= "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache." ) -def serve( - model: str, tokenizer: Optional[str], host: str, port: int, - log_level: str, backend: str, max_beam_width: int, max_batch_size: int, - max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int, - ep_size: Optional[int], cluster_size: Optional[int], - gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float, - num_postprocess_workers: int, trust_remote_code: bool, - extra_llm_api_options: Optional[str], reasoning_parser: Optional[str], - metadata_server_config_file: Optional[str], server_role: Optional[str], - fail_fast_on_attention_window_too_large: bool): +def serve(model: str, tokenizer: Optional[str], host: str, port: int, + log_level: str, backend: str, max_beam_width: int, + max_batch_size: int, max_num_tokens: int, max_seq_len: int, + tp_size: int, pp_size: int, ep_size: Optional[int], + cluster_size: Optional[int], gpus_per_node: Optional[int], + kv_cache_free_gpu_memory_fraction: float, mamba_ssm_cache_dtype: str, + num_postprocess_workers: int, trust_remote_code: bool, + extra_llm_api_options: Optional[str], reasoning_parser: Optional[str], + metadata_server_config_file: Optional[str], + server_role: Optional[str], + fail_fast_on_attention_window_too_large: bool): """Running an OpenAI API compatible server MODEL: model name | HF checkpoint path | TensorRT engine path @@ -307,6 +316,7 @@ def serve( moe_cluster_parallel_size=cluster_size, gpus_per_node=gpus_per_node, free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, + mamba_ssm_cache_dtype=mamba_ssm_cache_dtype, num_postprocess_workers=num_postprocess_workers, trust_remote_code=trust_remote_code, reasoning_parser=reasoning_parser, diff --git a/tensorrt_llm/disaggregated_params.py b/tensorrt_llm/disaggregated_params.py index 16cfb7d384..4dfaa5bca4 100644 --- a/tensorrt_llm/disaggregated_params.py +++ b/tensorrt_llm/disaggregated_params.py @@ -1,6 +1,11 @@ from dataclasses import dataclass from typing import List, Optional +# isort: off +# needed before trying to import bindings to load tensorrt_libs +import tensorrt as trt # noqa +# isort: on + from tensorrt_llm.bindings import executor as tllme diff --git a/tensorrt_llm/evaluate/interface.py b/tensorrt_llm/evaluate/interface.py index 5e88d3f309..3a038919ce 100644 --- a/tensorrt_llm/evaluate/interface.py +++ b/tensorrt_llm/evaluate/interface.py @@ -33,11 +33,13 @@ class Evaluator(ABC): def __init__(self, random_seed: int = 0, apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, system_prompt: Optional[str] = None): random.seed(random_seed) np.random.seed(random_seed) torch.manual_seed(random_seed) self.apply_chat_template = apply_chat_template + self.fewshot_as_multiturn = fewshot_as_multiturn self.system_prompt = system_prompt @abstractmethod diff --git a/tensorrt_llm/evaluate/json_mode_eval.py b/tensorrt_llm/evaluate/json_mode_eval.py index 1854488bce..37360754e5 100644 --- a/tensorrt_llm/evaluate/json_mode_eval.py +++ b/tensorrt_llm/evaluate/json_mode_eval.py @@ -18,6 +18,7 @@ from typing import Iterable, List, Optional, Union import click import datasets +import jsonschema import numpy as np from .. import LLM as PyTorchLLM @@ -65,23 +66,30 @@ class JsonModeEval(Evaluator): sampling_args = { "guided_decoding": GuidedDecodingParams(json=schema) } - yield sample["prompt"], sampling_args, sample["completion"] + yield sample["prompt"], sampling_args, sample["completion"], sample[ + "schema"] - def compute_score(self, outputs: List[RequestOutput], - references: List[str]) -> float: - all_corrections = [] - for output, ref in zip(outputs, references): + def compute_score(self, outputs: List[RequestOutput], references: List[str], + schemas: List[str]) -> float: + all_corrections, all_grammar_corrections = [], [] + for output, ref, schema in zip(outputs, references, schemas): try: output_json = json.loads(output.outputs[0].text) - except json.JSONDecodeError: + jsonschema.validate(output_json, json.loads(schema)) + except (json.JSONDecodeError, jsonschema.ValidationError): all_corrections.append(False) + all_grammar_corrections.append(False) continue - ref_json = json.loads(ref) - all_corrections.append(output_json == ref_json) + all_corrections.append(output_json == json.loads(ref)) + all_grammar_corrections.append(True) acc = np.mean(all_corrections) * 100 logger.info( f"JSON Mode Eval accuracy: {acc:.2f} ({len(all_corrections)})") + grammar_acc = np.mean(all_grammar_corrections) * 100 + logger.info( + f"JSON Mode Eval grammar accuracy: {grammar_acc:.2f} ({len(all_grammar_corrections)})" + ) return acc @click.command("json_mode_eval") diff --git a/tensorrt_llm/evaluate/lm_eval.py b/tensorrt_llm/evaluate/lm_eval.py index bdddbcbb73..920299b103 100644 --- a/tensorrt_llm/evaluate/lm_eval.py +++ b/tensorrt_llm/evaluate/lm_eval.py @@ -25,6 +25,7 @@ import tensorrt_llm.profiler as profiler try: from lm_eval.api.model import TemplateLM + from lm_eval.tasks import TaskManager except ImportError: TemplateLM = object @@ -132,6 +133,7 @@ class LmEvalEvaluator(Evaluator): num_samples: Optional[int] = None, random_seed: int = 0, apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, system_prompt: Optional[str] = None): try: import lm_eval @@ -140,14 +142,16 @@ class LmEvalEvaluator(Evaluator): f"Evaluation task {self.__class__.__name__} requires `lm_eval`. " "Please install the package first, e.g., `pip install lm_eval`." ) from e + import lm_eval.tasks super().__init__(random_seed=random_seed, apply_chat_template=apply_chat_template, + fewshot_as_multiturn=fewshot_as_multiturn, system_prompt=system_prompt) self.task_name = task_name self.dataset_path = dataset_path self.num_samples = num_samples - task_manager = lm_eval.tasks.TaskManager( + task_manager = TaskManager( include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks") with self._patch_lm_eval(): self.task_dict = lm_eval.tasks.get_task_dict( @@ -189,14 +193,16 @@ class LmEvalEvaluator(Evaluator): def evaluate(self, llm: Union[LLM, PyTorchLLM], sampling_params: Optional[SamplingParams] = None, - streaming: bool = False) -> float: + streaming: bool = False, + scores_filter: str = None) -> float: import lm_eval - results = lm_eval.evaluate(lm=LmEvalWrapper(llm, sampling_params, - streaming), - task_dict=self.task_dict, - limit=self.num_samples, - apply_chat_template=self.apply_chat_template, - system_instruction=self.system_prompt) + results = lm_eval.evaluate( + lm=LmEvalWrapper(llm, sampling_params, streaming), + task_dict=self.task_dict, + limit=self.num_samples, + apply_chat_template=self.apply_chat_template, + fewshot_as_multiturn=self.fewshot_as_multiturn, + system_instruction=self.system_prompt) # Normalize scores to range 0~100 scores = results["results"][self.task_name] for metric in scores.keys(): @@ -205,12 +211,17 @@ class LmEvalEvaluator(Evaluator): logger.info( f"lm-eval {self.task_name} results (scores normalized to range 0~100):\n{lm_eval.utils.make_table(results)}" ) - - average_acc = np.mean( - [acc for m, acc in scores.items() if "_stderr" not in m]) - logger.info( - f"lm-eval {self.task_name} average accuracy: {average_acc:.2f}") - return average_acc + if scores_filter is not None: + result_acc = results["results"][self.task_name][scores_filter] + logger.info( + f"lm-eval {self.task_name} {scores_filter} accuracy: {result_acc:.2f}" + ) + else: + result_acc = np.mean( + [acc for m, acc in scores.items() if "_stderr" not in m]) + logger.info( + f"lm-eval {self.task_name} average accuracy: {result_acc:.2f}") + return result_acc @classmethod def command_harness(cls, ctx, **kwargs): @@ -220,6 +231,8 @@ class LmEvalEvaluator(Evaluator): random_seed=kwargs.pop("random_seed", 0), apply_chat_template=kwargs.pop("apply_chat_template", False), + fewshot_as_multiturn=kwargs.pop("fewshot_as_multiturn", + False), system_prompt=kwargs.pop("system_prompt", None)) sampling_params = SamplingParams( max_tokens=kwargs.pop("max_output_length"), @@ -253,6 +266,10 @@ class GSM8K(LmEvalEvaluator): is_flag=True, default=False, help="Whether to apply chat template.") + @click.option("--fewshot_as_multiturn", + is_flag=True, + default=False, + help="Apply fewshot as multiturn.") @click.option("--system_prompt", type=str, default=None, @@ -268,6 +285,10 @@ class GSM8K(LmEvalEvaluator): @click.pass_context @staticmethod def command(ctx, **kwargs) -> None: + if kwargs.get("fewshot_as_multiturn", False): + assert kwargs.get( + "apply_chat_template", False + ), "apply_chat_template must be True when fewshot_as_multiturn is True" GSM8K.command_harness(ctx, **kwargs) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 9ce4ad0d85..14c8eeb389 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -15,7 +15,7 @@ import torch from tensorrt_llm.inputs.multimodal import MultimodalParams from tensorrt_llm.logger import logger, set_level -from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.lora_helper import LoraConfig from .._utils import mpi_world_size from ..bindings import executor as tllm diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index 2e5a3cd296..7dff328918 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -3,7 +3,7 @@ import traceback from collections import deque from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, - Optional) + Optional, Union) import zmq import zmq.asyncio @@ -18,7 +18,7 @@ from .utils import is_llm_response if TYPE_CHECKING: from .result import (DetokenizedGenerationResultBase, GenerationResult, - GenerationResultBase) + GenerationResultBase, ResponseWrapper) __all__ = [ "PostprocWorker", @@ -57,7 +57,7 @@ class PostprocWorker: @dataclass class Input: - rsp: "tllm.Response" + rsp: Union["tllm.Response", "ResponseWrapper"] # The information necessary for creating a GenerationResult in the first Input for each request sampling_params: Optional[SamplingParams] = None @@ -69,6 +69,7 @@ class PostprocWorker: res: Any is_final: bool error: str = "" + metrics: Optional[dict[str, float]] = None def __init__( self, @@ -118,7 +119,9 @@ class PostprocWorker: streaming=inp.streaming, tokenizer=tokenizer) - async def _handle_input(self, input: "PostprocWorker.Input") -> Any: + async def _handle_input( + self, input: Union["PostprocWorker.Input", "ResponseWrapper"] + ) -> [Any, Optional[dict[str, float]]]: ''' Handle a single response from await_response worker. ''' if input.rsp.result.context_logits is not None or \ input.rsp.result.generation_logits is not None: @@ -139,6 +142,7 @@ class PostprocWorker: record._handle_response(input.rsp) # inplace # Left the result_handler determine the final output dtype. # NOTE: This will change the CompletionOutput._postprocess_result + metrics_dict = record.metrics_dict if postproc_params := record.postproc_params: result_handler, args = postproc_params.post_processor, postproc_params.postproc_args args.tokenizer = self._tokenizer @@ -150,7 +154,7 @@ class PostprocWorker: # TODO: Keep only the diff token_ids and text in streaming mode when # result_handler is not set - return out + return out, metrics_dict async def _batched_put(self): ''' Batched IPC send. ''' @@ -173,8 +177,12 @@ class PostprocWorker: client_id = inp.rsp.client_id is_final = inp.rsp.result.is_final if is_llm_response( inp.rsp) else True - res = await self._handle_input(inp) - batch.append(PostprocWorker.Output(client_id, res, is_final)) + res, metrics = await self._handle_input(inp) + batch.append( + PostprocWorker.Output(client_id=client_id, + res=res, + is_final=is_final, + metrics=metrics)) if is_final: self._records.pop(client_id) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 0408a6c757..2566a699aa 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -15,6 +15,7 @@ from ..bindings import executor as tllm from ..disaggregated_params import DisaggregatedParams from ..llmapi.tracer import global_tracer from ..llmapi.utils import AsyncQueue +from ..metrics import MetricNames, MetricsCollector, RequestEventTiming from ..sampling_params import LogprobParams, SamplingParams from .utils import ErrorResponse, has_event_loop, is_llm_response @@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple): class ResponseWrapper: - """Wrapper of runtime response with optional outputs computed post runtime. + """ + 1. Wrapper of runtime response with optional outputs computed post runtime. + 2. A workaround to pass around RequestPerfMetrics. """ def __init__(self, response: Union["PostprocWorker.Output", tllm.Response], - logprobs: Optional[LogProbsResult] = None): + logprobs: Optional[LogProbsResult] = None, + request_perf_metrics: Optional[dict[str, float]] = None): self._response = response self.logprobs = logprobs + self.request_perf_metrics = request_perf_metrics @property def _is_llm_response(self): @@ -68,6 +73,14 @@ class ResponseWrapper: response = object.__getattribute__(self, '_response') return getattr(response, name) + def __getstate__(self): + return (self._response, self.logprobs, self.request_perf_metrics) + + def __setstate__(self, state): + self._response = state[0] + self.logprobs = state[1] + self.request_perf_metrics = state[2] + @dataclass(slots=True) class CompletionOutput: @@ -146,6 +159,7 @@ class GenerationResultBase: self.disaggregated_params = None self.decoding_iter = 0 self._done = False + self.metrics_dict = {} if has_event_loop(): self.aqueue = AsyncQueue() @@ -201,7 +215,9 @@ class GenerationResultBase: finish_reasons, response_tensors, sequence_index, - logprobs_result=None): + logprobs_result=None, + req_perf_metrics_dict: Optional[dict[str, + float]] = None): """ Handle a single sequence in the response. """ seq_idx = sequence_index @@ -271,6 +287,7 @@ class GenerationResultBase: else: raise ValueError( f"Unknown finish reason: {finish_reasons[src_idx]}") + self.record_stats(output, req_perf_metrics_dict) @nvtx_range_debug("handle_response", color="red", @@ -278,7 +295,9 @@ class GenerationResultBase: def _handle_response(self, response: Union["PostprocWorker.Output", tllm.Response, ResponseWrapper, ErrorResponse]): + req_perf_metrics_dict = None if isinstance(response, ResponseWrapper): + req_perf_metrics_dict = response.request_perf_metrics logprobs_result = response.logprobs response = response._response else: @@ -291,6 +310,8 @@ class GenerationResultBase: self._outputs[0] = response.res else: self._outputs[0]._postprocess_result = response.res + if response.metrics: + self.metrics_dict = response.metrics if response.error: if self._background_error_handler is not None and ( @@ -303,7 +324,8 @@ class GenerationResultBase: handler(response.error_msg) response_result = response.result - if hasattr(response_result, "_result"): + if hasattr(response_result, "_result") and isinstance( + response_result._result, bytes): response_result.deserialize() self._done = response_result.is_final @@ -322,11 +344,12 @@ class GenerationResultBase: if self.sampling_params.use_beam_search: for beam_idx, _ in enumerate(response_result.output_token_ids): self._handle_sequence(finish_reasons, response_result, - beam_idx, logprobs_result) + beam_idx, logprobs_result, + req_perf_metrics_dict) else: self._handle_sequence(finish_reasons, response_result, response_result.sequence_index, - logprobs_result) + logprobs_result, req_perf_metrics_dict) if response_result.context_logits is not None: self._context_logits = response_result.context_logits @@ -342,6 +365,29 @@ class GenerationResultBase: else: raise ValueError(f"Unknown response type: {response}") + def record_stats(self, + output: CompletionOutput, + stats: Optional[dict[str, float]] = None) -> None: + """Record the stats of the generation result. + + Args: + output (CompletionOutput): The output of the generation result. + stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None. + """ + if not stats: + return + metrics_stats = {} + if output.finish_reason: + metrics_stats.update({ + MetricsCollector.labelname_finish_reason: + output.finish_reason + }) + processed_metrics_stat = _process_req_perf_metrics( + stats, len(output.token_ids), self.sampling_params.n > 1) + if processed_metrics_stat: + metrics_stats.update(processed_metrics_stat) + self.metrics_dict = metrics_stats + class DetokenizedGenerationResultBase(GenerationResultBase): ''' The base class for the generation result with detokenization support. ''' @@ -688,3 +734,30 @@ def compute_logprobs( return LogProbsResult(prompt=prompt_logprobs, generation=generation_logprobs) + + +def _process_req_perf_metrics( + req_perf_metrics_dict: Optional[dict[str, float]], + output_length: int, + is_multiple_response: bool = False) -> dict[MetricNames, float]: + stat = {} + if not req_perf_metrics_dict: + return stat + ttft = req_perf_metrics_dict.get(RequestEventTiming.FIRST_TOKEN_TIME, 0) - \ + req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0) + e2e = req_perf_metrics_dict.get(RequestEventTiming.LAST_TOKEN_TIME, 0) - \ + req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0) + request_queue_time = req_perf_metrics_dict.get(RequestEventTiming.FIRST_SCHEDULED_TIME, 0) - \ + req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0) + stat = { + MetricNames.TTFT: ttft, + MetricNames.E2E: e2e, + MetricNames.REQUEST_QUEUE_TIME: request_queue_time + } + if output_length > 1 and not is_multiple_response: + tpot = (req_perf_metrics_dict.get( + RequestEventTiming.LAST_TOKEN_TIME, 0) - req_perf_metrics_dict.get( + RequestEventTiming.FIRST_TOKEN_TIME, 0)) / (output_length - 1) + stat.update({MetricNames.TPOT: tpot}) + stat = dict(filter(lambda item: item[1] > 0, stat.items())) + return stat diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 33ed146c9c..6d5ec9c1d7 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -24,7 +24,9 @@ from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, clear_sched_affinity, print_colored_debug, print_traceback_on_error) -from ..lora_manager import LoraConfig, LoraManager +from ..lora_helper import LoraConfig +from ..lora_manager import LoraManager +from ..metrics import RequestEventTiming from ..prompt_adapter_manager import PromptAdapterManager from ..runtime import ModelConfig from ..runtime.model_runner import _engine_config_to_model_config @@ -372,6 +374,7 @@ class GenerationExecutorWorker(GenerationExecutor): def _enqueue_request(self, request: GenerationRequest) -> int: assert request.id is not None + py_lora_path = None if self._lora_manager is not None and request.lora_request is not None: adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache( request.lora_request.adapter_id) @@ -381,8 +384,8 @@ class GenerationExecutorWorker(GenerationExecutor): task_id=request.lora_request.adapter_id, weights=self._lora_manager.cpp_lora_weights[uid] if not adapter_in_cache else None, - config=self._lora_manager.cpp_lora_config[uid] - if not adapter_in_cache else None) + config=self._lora_manager.cpp_lora_config[uid]) + py_lora_path = request.lora_request.lora_path else: lora_config = None @@ -485,7 +488,7 @@ class GenerationExecutorWorker(GenerationExecutor): lora_config=lora_config, prompt_tuning_config=prompt_tuning_config, multimodal_input=multimodal_input, - #NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`. + # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`. multimodal_embedding=None, mrope_config=None, logits_post_processor_name=( @@ -497,20 +500,12 @@ class GenerationExecutorWorker(GenerationExecutor): kv_cache_retention_config=request.kv_cache_retention_config, context_phase_params=context_phase_params, type=request_type) + executor_request.py_lora_path = py_lora_path if self._is_pytorch_backend and request.multimodal_params is not None: if request.multimodal_params.multimodal_data is not None: - # Convert back to tensor, as opposite to `to_handle` in `llm.generate_async` - # for values with non-selected keys, it's no-op - request.multimodal_params.to_tensor( - "multimodal_data", key="multimodal_embedding") - embedding = request.multimodal_params.multimodal_data.get( - "multimodal_embedding") - if embedding is not None and embedding.is_cuda: - # make sure the embedding resides on the local device - request.multimodal_params.multimodal_data[ - "multimodal_embedding"] = embedding.to("cuda") - + # NOTE: Deserialize SharedTensor handle to actual tensor + request.multimodal_params.to_tensor("multimodal_data") executor_request.py_multimodal_data = request.multimodal_params.multimodal_data if self._is_pytorch_backend and request.sampling_params.logits_processor: @@ -897,10 +892,8 @@ class AwaitResponseHelper: assert response is not None queue = self.worker.return_queue(response.client_id) - logprobs_result = _get_logprobs(self.worker, response, + response = _maybe_wrap_response(self.worker, response, self.worker._is_pytorch_backend) - if logprobs_result: - response = ResponseWrapper(response, logprobs_result) # For AsyncQueue.sync_q, we will batch the events to avoid too many # event notifications, thus put without wait here. @@ -930,7 +923,9 @@ class AwaitResponseHelper: for response in responses: - if self.worker._has_background_error(): + if isinstance(response, ErrorResponse): + pass # send ErrorResponse directly + elif self.worker._has_background_error(): response = self.worker._create_error_response(response) elif response.has_error(): # Convert to ErrorResponse, because tllm.Response cannot be @@ -938,10 +933,8 @@ class AwaitResponseHelper: response = ErrorResponse(response.client_id, response.error_msg, response.request_id) else: - logprobs_result = _get_logprobs(self.worker, response, + response = _maybe_wrap_response(self.worker, response, self.worker._is_pytorch_backend) - if logprobs_result: - response = ResponseWrapper(response, logprobs_result) _send_rsp(self.worker, response, @@ -1049,3 +1042,41 @@ def _send_rsp( worker._pop_result(response.client_id) else: raise ValueError(f"Unknown response type: {response}") + + +def _get_metrics_dict( + response: tllm.Response) -> dict[RequestEventTiming, float]: + req_perf_metrics, metrics_dict = None, {} + res = response.result + if res: + if hasattr(res, '_result'): + if result := res.get_result(): + req_perf_metrics = result.request_perf_metrics + else: + req_perf_metrics = res.request_perf_metrics + if req_perf_metrics and req_perf_metrics.timing_metrics: + metrics_dict = { + RequestEventTiming.ARRIVAL_TIME: + req_perf_metrics.timing_metrics.arrival_time.total_seconds(), + RequestEventTiming.FIRST_TOKEN_TIME: + req_perf_metrics.timing_metrics.first_token_time.total_seconds( + ), + RequestEventTiming.FIRST_SCHEDULED_TIME: + req_perf_metrics.timing_metrics.first_scheduled_time. + total_seconds(), + RequestEventTiming.LAST_TOKEN_TIME: + req_perf_metrics.timing_metrics.last_token_time.total_seconds() + } + return metrics_dict + + +def _maybe_wrap_response( + worker, + response: tllm.Response, + is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]: + + logprobs_result = _get_logprobs(worker, response, is_pytorch_backend) + req_perf_metrics = _get_metrics_dict(response) + if logprobs_result or req_perf_metrics: + response = ResponseWrapper(response, logprobs_result, req_perf_metrics) + return response diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 06880bc430..2492eb6a61 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -30,9 +30,9 @@ from . import graph_rewriting as gw from ._common import default_net, default_trtnet, precision from ._utils import (QuantModeWrapper, bf16_array, bool_array, dim_resolve_negative, dim_to_trt_axes, dims_array, - fp16_array, fp32_array, int32_array, int64_array, - np_dtype_to_trt, str_dtype_to_trt, trt_dtype_to_np, - trt_dtype_to_str) + fp16_array, fp32_array, get_sm_version, int32_array, + int64_array, np_dtype_to_trt, str_dtype_to_trt, + trt_dtype_to_np, trt_dtype_to_str) from .network import PluginInfo, set_np_weight, set_plugin_info from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper from .quantization import QuantMode @@ -3882,6 +3882,7 @@ class AllReduceStrategy(IntEnum): TWOSHOT = 5 LOWPRECISION = 6 MNNVL = 7 + NCCL_SYMMETRIC = 8 class AllReduceFusionOp(IntEnum): @@ -4733,6 +4734,15 @@ class RopeEmbeddingUtils: inv_freq = 1.0 / (theta**(np.arange(0, dim, 2) / dim)).astype(dtype) inv_freq = RopeEmbeddingUtils.apply_llama3_scaling( inv_freq, rope_scaling_config) + elif scale_type == RotaryScalingType.dynamic: + # Make sure scaling_alpha exists in rope_scaling + # Ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct-FP8/blob/main/modeling_hunyuan.py#L346 + assert rope_scaling_config[ + "alpha"] is not None, "rope_scaling_config.alpha must be provided." + scaling_alpha = rope_scaling_config["alpha"] + adjusted_base = theta * (scaling_alpha**(dim / (dim - 2))) + inv_freq = 1.0 / (adjusted_base**( + np.arange(0, dim, 2, dtype=dtype) / dim)).astype(dtype) else: inv_freq = scale / (theta **(np.arange(0, dim, 2) / dim)).astype(dtype) @@ -5718,7 +5728,8 @@ def gpt_attention( if (attention_mask is not None) or (attention_packed_mask is not None): # context fmha needs packed mask. assert attention_packed_mask is not None - mask_type = AttentionMaskType.custom_mask + if get_sm_version() < 100: + mask_type = AttentionMaskType.custom_mask mask_type_filed = trt.PluginField("mask_type", np.array([int(mask_type)], np.int32), @@ -5843,7 +5854,7 @@ def gpt_attention( if attention_mask is not None and mask_type == AttentionMaskType.custom_mask: # useFullCustomMask plug_inputs += [attention_mask] - if attention_packed_mask is not None: + if attention_packed_mask is not None and get_sm_version() < 100: # usePackedCustomMask plug_inputs += [attention_packed_mask] if use_cache: diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py index a20978cab4..f7e5ce97d7 100644 --- a/tensorrt_llm/inputs/__init__.py +++ b/tensorrt_llm/inputs/__init__.py @@ -1,16 +1,23 @@ from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs from .multimodal import MultimodalInput from .registry import (ExtraProcessedInputs, InputProcessor, - create_input_processor, create_input_processor_with_hash, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, create_input_processor, + create_input_processor_with_hash, register_input_processor) -from .utils import (ALL_SUPPORTED_MULTIMODAL_MODELS, ConversationMessage, - MultimodalData, MultimodalDataTracker, +from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS, + ALL_SUPPORTED_MULTIMODAL_MODELS, ALL_SUPPORTED_VIDEO_MODELS, + ConversationMessage, MultimodalData, MultimodalDataTracker, add_multimodal_placeholders, async_load_audio, async_load_image, async_load_video, default_multimodal_input_loader, encode_base64_content_from_url, load_image, load_video) __all__ = [ + "ALL_SUPPORTED_MULTIMODAL_MODELS", + "ALL_SUPPORTED_IMAGE_MODELS", + "ALL_SUPPORTED_VIDEO_MODELS", + "ALL_SUPPORTED_AUDIO_MODELS", "PromptInputs", "prompt_inputs", "TextPrompt", @@ -20,7 +27,8 @@ __all__ = [ "create_input_processor_with_hash", "register_input_processor", "ExtraProcessedInputs", - "ALL_SUPPORTED_MULTIMODAL_MODELS", + "MultimodalPlaceholderMetadata", + "MultimodalPlaceholderPlacement", "ConversationMessage", "MultimodalDataTracker", "MultimodalData", diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index 9368906587..77eec0ff6a 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -159,6 +159,10 @@ class MultimodalParams: multimodal_input: Multimodal input data with hashing information. multimodal_data: Processed multimodal data containing embeddings, configurations, and modality-specific data organized by type. + multimodal_runtime: Runtime data for tracking multimodal token caching and reuse + during KV cache scenarios. Contains information about cached + tokens, multimodal token positions, and lengths for efficient + processing during inference. Structure of multimodal_data: { @@ -190,6 +194,168 @@ class MultimodalParams: if self.multimodal_data is None: self.multimodal_data = {} + def _is_shared_tensor_dict(self, obj: Any) -> bool: + """Check if an object is a shared tensor dictionary. + + Args: + obj: Object to check + + Returns: + True if the object is a shared tensor dictionary, False otherwise + """ + if not isinstance(obj, dict): + return False + + # Check for required keys that uniquely identify a shared tensor dict + required_keys = {'method_key'} + if not required_keys.issubset(obj.keys()): + return False + + # Additional validation based on method_key + method_key = obj.get('method_key') + + # Import here to avoid circular imports + from tensorrt_llm._torch.shared_tensor import \ + _SharedTensorRebuildMethodRegistry + + if method_key == _SharedTensorRebuildMethodRegistry.REBUILD_CUDA: + cuda_keys = {'tensor_size', 'storage_handle', 'storage_device'} + return cuda_keys.issubset(obj.keys()) + elif method_key == _SharedTensorRebuildMethodRegistry.REBUILD_CPU: + cpu_keys = {'tensor_size', 'storage_handle', 'manager_handle'} + return cpu_keys.issubset(obj.keys()) + + return False + + def _apply_tensor_operation( + self, input_data: Union[torch.Tensor, List, dict, None], + operation: str, **kwargs) -> Union[torch.Tensor, List, dict, None]: + """Apply tensor operations recursively to nested data structures. + + This method handles three types of operations: + - "to_handle": Convert tensors to shared tensor dictionaries + - "to_tensor": Convert shared tensor dictionaries back to tensors + - "to_device": Move tensors to specified device + + Args: + input_data: Input data structure (tensor, list, dict, or None) + operation: Operation to apply + **kwargs: Additional arguments for the operation + + Returns: + Transformed data structure + """ + # Handle None case + if input_data is None: + return None + + # Handle list case - recursively process each element + if isinstance(input_data, list): + return [ + self._apply_tensor_operation(item, operation, **kwargs) + for item in input_data + ] + + # Handle dictionary case + if isinstance(input_data, dict): + if operation == "to_tensor" and self._is_shared_tensor_dict( + input_data): + # Convert shared tensor dict back to tensor + try: + # Import here to avoid circular imports + from tensorrt_llm._torch.shared_tensor import \ + SharedTensorContainer + + return SharedTensorContainer.from_dict( + input_data).get_local_view() + except Exception as e: + raise RuntimeError( + f"Failed to restore tensor from shared tensor dict: {e}" + ) + else: + # Regular dictionary - recursively process values + return { + key: self._apply_tensor_operation(value, operation, + **kwargs) + for key, value in input_data.items() + } + + # Handle tensor case + if isinstance(input_data, torch.Tensor): + if operation == "to_handle": + try: + # Import here to avoid circular imports + from tensorrt_llm._torch.shared_tensor import \ + SharedTensorContainer + return SharedTensorContainer.from_tensor( + input_data).dump_to_dict() + except Exception as e: + raise RuntimeError( + f"Failed to convert tensor to shared tensor: {e}") + elif operation == "to_device": + device = kwargs.get('device') + if device is None: + raise ValueError( + "Device must be specified for 'to_device' operation") + + pin_memory = kwargs.get('pin_memory', False) + try: + if pin_memory and input_data.device.type == 'cpu': + return input_data.pin_memory().to(device, + non_blocking=True) + else: + return input_data.to(device, non_blocking=True) + except Exception as e: + raise RuntimeError( + f"Failed to move tensor to device {device}: {e}") + + # For any other type, return as-is + return input_data + + def to_handle(self, element: str) -> None: + """Move specified multimodal data element to shared tensor. + + Args: + element: Element to move (only "multimodal_data" is supported) + + Raises: + ValueError: If element is not "multimodal_data" + RuntimeError: If tensor conversion fails + """ + if element != "multimodal_data": + raise ValueError( + f"Unsupported element '{element}'. Only 'multimodal_data' is supported." + ) + + data = getattr(self, element) + if data is None: + return # Nothing to convert + + transformed_data = self._apply_tensor_operation(data, "to_handle") + setattr(self, element, transformed_data) + + def to_tensor(self, element: str) -> None: + """Move specified multimodal data element from shared tensor. + + Args: + element: Element to restore (only "multimodal_data" is supported) + + Raises: + ValueError: If element is not "multimodal_data" + RuntimeError: If tensor restoration fails + """ + if element != "multimodal_data": + raise ValueError( + f"Unsupported element '{element}'. Only 'multimodal_data' is supported." + ) + + data = getattr(self, element) + if data is None: + return # Nothing to restore + + restored_data = self._apply_tensor_operation(data, "to_tensor") + setattr(self, element, restored_data) + def to_device(self, element: str, device: str, @@ -197,230 +363,47 @@ class MultimodalParams: """Move specified multimodal data element to target device. Args: - element: Element to move ("multimodal_data" or "multimodal_input") + element: Element to move (only "multimodal_data" is supported) device: Target device (e.g., "cuda", "cpu") pin_memory: Whether to pin memory for faster transfers + + Raises: + ValueError: If element is not "multimodal_data" or device is invalid + RuntimeError: If device transfer fails """ - - def _to_device( - input_tensor: Union[torch.Tensor, List, dict, None], - pin_memory: bool = False, - ) -> Union[torch.Tensor, List, dict, None]: - if input_tensor is None: - return None - elif isinstance(input_tensor, list): - return [_to_device(item, pin_memory) for item in input_tensor] - elif isinstance(input_tensor, dict): - return { - key: _to_device(value, pin_memory) - for key, value in input_tensor.items() - } - elif isinstance(input_tensor, torch.Tensor): - if pin_memory and input_tensor.device.type == 'cpu': - return input_tensor.pin_memory().to(device, - non_blocking=True) - else: - return input_tensor.to(device, non_blocking=True) - else: - return input_tensor - - if element == "multimodal_data": - self.multimodal_data = _to_device(self.multimodal_data, pin_memory) - elif element == "multimodal_input": - self.multimodal_input = _to_device(self.multimodal_input, - pin_memory) - else: - print( - f"MultimodalParams: Unsupported element '{element}' to move to device. " - f"Supported elements: 'multimodal_data', 'multimodal_input'") - - def to_handle(self, element: str, key: Optional[str] = None) -> None: - """Convert multimodal data to tensor handle. - - Converts torch.Tensor objects to SharedTensorContainer handles (serializable dictionaries) - for efficient IPC. This function is a in-place operation. - - Args: - element: Element to convert ("multimodal_data" or "multimodal_input") - key: Specific key to convert. If None, converts all tensor values in multimodal_data. - Defaults to None. - - Example: - # Convert all tensors in multimodal_data to handles - params.to_handle("multimodal_data", key=None) - - # Convert only multimodal_embedding section tensors to handles - params.to_handle("multimodal_data", key="multimodal_embedding") - """ - # Lazy import to avoid circular dependency - from tensorrt_llm._torch.shared_tensor import SharedTensorContainer - - def _to_tensor_handle(data): - for k, v in data.items(): - if isinstance(v, torch.Tensor): - # Convert tensor to handle - handle = SharedTensorContainer.from_tensor(v).dump_to_dict() - data[k] = handle - elif isinstance(v, dict): - _to_tensor_handle(v) - elif isinstance(v, list): - for i, item in enumerate(v): - if isinstance(item, torch.Tensor): - handle = SharedTensorContainer.from_tensor( - item).dump_to_dict() - v[i] = handle - - if element == "multimodal_data": - if self.multimodal_data is None: - return - if key is None: - _to_tensor_handle(self.multimodal_data) - else: - if key not in self.multimodal_data: - return # no-op if key not found - - value = self.multimodal_data[key] - if isinstance(value, torch.Tensor): - handle = SharedTensorContainer.from_tensor( - value).dump_to_dict() - self.multimodal_data[key] = handle - elif isinstance(value, dict): - _to_tensor_handle(value) - else: - raise ValueError( - f"Unsupported value type for multimodal_data: {type(value)}" - ) - elif element == "multimodal_input": - # No-op for multimodal_input - return - else: + if element != "multimodal_data": raise ValueError( - f"Unsupported element '{element}' to convert to handle.") + f"Unsupported element '{element}'. Only 'multimodal_data' is supported." + ) - def to_tensor(self, element: str, key: Optional[str] = None) -> None: - """Convert multimodal tensor handles back to tensors. This is the dual operation to to_handle. + data = getattr(self, element) + if data is None: + return # Nothing to move - Converts SharedTensorContainer handles (serializable dictionaries) back to torch.Tensor objects - for local computation. This function performs in-place modifications to the multimodal_data. - - Args: - element: Element to convert ("multimodal_data" or "multimodal_input") - key: Specific key to convert. If None, converts all tensor handles in multimodal_data. - Defaults to None. - - Example: - # Convert all handles back to tensors - params.to_tensor("multimodal_data", key=None) - - # Convert only multimodal_embedding section handles back to tensors - params.to_tensor("multimodal_data", key="multimodal_embedding") - """ - # Lazy import to avoid circular dependency - from tensorrt_llm._torch.shared_tensor import SharedTensorContainer - - def _to_tensor(data): - for k, v in data.items(): - if isinstance(v, dict) and 'method_key' in v: - # This is a tensor handle (dict with method_key) - try: - tensor = SharedTensorContainer.from_dict( - v).get_local_view() - data[k] = tensor - except Exception as e: - raise ValueError( - f"Failed to convert handle to tensor for key '{k}': {e}" - ) - elif isinstance(v, dict): - _to_tensor(v) - elif isinstance(v, list): - for i, item in enumerate(v): - if isinstance(item, dict) and 'method_key' in item: - try: - tensor = SharedTensorContainer.from_dict( - item).get_local_view() - v[i] = tensor - except Exception as e: - raise ValueError( - f"Failed to convert handle to tensor in list at index {i}: {e}" - ) - - if element == "multimodal_data": - if self.multimodal_data is None: - return - - if key is None: - _to_tensor(self.multimodal_data) - else: - if key not in self.multimodal_data: - return # no-op if key not found - - value = self.multimodal_data[key] - if isinstance( - value, dict - ) and 'method_key' in value: # This is a tensor handle - try: - tensor = SharedTensorContainer.from_dict( - value).get_local_view() - self.multimodal_data[key] = tensor - except Exception as e: - raise ValueError( - f"Failed to convert handle to tensor for key '{key}': {e}" - ) - elif isinstance(value, dict): - _to_tensor(value) - else: - raise ValueError( - f"Unsupported value type for multimodal_data: {type(value)}" - ) - - elif element == "multimodal_input": - # No-op for multimodal_input - return - else: - raise ValueError( - f"Unsupported element '{element}' to convert to tensor.") - - def strip_for_context(self) -> None: - """Strip multimodal data for context processing. - - Removes only mrope_position_deltas while keeping all other multimodal data - (embeddings, images, etc.) needed for context phase processing. - """ - if not (self.multimodal_data - and 'mrope_config' in self.multimodal_data): - return - - mrope_config = self.multimodal_data['mrope_config'] - if 'mrope_position_deltas' in mrope_config: - del mrope_config['mrope_position_deltas'] - - # Clean up empty mrope_config - if not mrope_config: - del self.multimodal_data['mrope_config'] + transformed_data = self._apply_tensor_operation(data, + "to_device", + device=device, + pin_memory=pin_memory) + setattr(self, element, transformed_data) def strip_for_generation(self) -> None: """Strip multimodal data for generation processing. - Keeps only mrope_position_deltas and removes all other multimodal data + Keeps only mrope_config and removes all other multimodal data (embeddings, images, etc.) as they're not needed during generation. """ if not self.multimodal_data: return - # Extract mrope_position_deltas before clearing - mrope_position_deltas = None + # Extract mrope_config before clearing + mrope_config = None if 'mrope_config' in self.multimodal_data: mrope_config = self.multimodal_data['mrope_config'] - if isinstance(mrope_config, - dict) and 'mrope_position_deltas' in mrope_config: - mrope_position_deltas = mrope_config['mrope_position_deltas'] - # Clear all data and restore only position deltas if they exist + # Clear all data and restore only mrope_config if it exists self.multimodal_data = {} - if mrope_position_deltas is not None: - self.multimodal_data['mrope_config'] = { - 'mrope_position_deltas': mrope_position_deltas - } + if mrope_config is not None: + self.multimodal_data['mrope_config'] = mrope_config def has_content(self) -> bool: """Check if this object contains any multimodal data.""" diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py index 010eb674a2..e75ba73908 100644 --- a/tensorrt_llm/inputs/registry.py +++ b/tensorrt_llm/inputs/registry.py @@ -1,3 +1,5 @@ +import enum +from dataclasses import dataclass, field from typing import (Any, Callable, Dict, List, Optional, Protocol, Tuple, Type, TypeVar) @@ -10,7 +12,6 @@ from .data import TextPrompt from .multimodal import (MultimodalInput, apply_mm_hashes, default_hasher, find_mm_token_lengths, find_mm_token_positions, hexdigest_to_int32, validate_mm_inputs) -from .utils import ALL_SUPPORTED_MULTIMODAL_MODELS N = TypeVar("N", bound=Type[nn.Module]) @@ -32,6 +33,7 @@ class InputProcessor(Protocol): model_path: any model_config: any tokenizer: any + multimodal_hashing_supported: Optional[bool] = None def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams @@ -50,6 +52,7 @@ class DefaultInputProcessor(InputProcessor): self.tokenizer = tokenizer self.model_config = model_config self.model_path = model_path + self.multimodal_hashing_supported = None def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams @@ -61,24 +64,174 @@ class DefaultInputProcessor(InputProcessor): if sampling_params.truncate_prompt_tokens is not None: kwargs = dict(truncation=True, max_length=sampling_params.truncate_prompt_tokens) - + toktoken_special_tokens = { + "<|startoftext|>", + "<|endoftext|>", + "<|reserved_200000|>", + "<|reserved_200001|>", + "<|return|>", + "<|constrain|>", + "<|reserved_200004|>", + "<|channel|>", + "<|start|>", + "<|end|>", + "<|message|>", + "<|reserved_200009|>", + "<|reserved_200010|>", + "<|reserved_200011|>", + "<|call|>", + "<|reserved_200013|>", + } with nvtx_range_debug("tokenize prompt"): - token_ids = self.tokenizer.encode( - inputs["prompt"], - add_special_tokens=sampling_params.add_special_tokens, - **kwargs) + try: + token_ids = self.tokenizer.encode( + inputs["prompt"], + add_special_tokens=sampling_params.add_special_tokens, + **kwargs) + except: + # Tiktoken path + token_ids = self.tokenizer.encode( + inputs["prompt"], allowed_special=toktoken_special_tokens) if "query" in inputs: with nvtx_range_debug("tokenize query"): - query_token_ids = self.tokenizer.encode( - inputs["query"], - add_special_tokens=sampling_params.add_special_tokens, - **kwargs) + try: + query_token_ids = self.tokenizer.encode( + inputs["query"], + add_special_tokens=sampling_params.add_special_tokens, + **kwargs) + except: + # Tiktoken path + query_token_ids = self.tokenizer.encode( + inputs["query"], + allowed_special=toktoken_special_tokens) + return token_ids, {"query_token_ids": query_token_ids} return token_ids, None +class MultimodalPlaceholderPlacement(enum.Enum): + """ + The placement of the multimodal placeholder in the prompt. Valid values are: + - BEFORE_TEXT: the placeholders are placed before the text prompt. + - AFTER_TEXT: the placeholders are placed after the text prompt. + """ + INVALID = -1 + BEFORE_TEXT = 0 + AFTER_TEXT = 1 + + +@dataclass(frozen=True) +class MultimodalPlaceholderMetadata: + """ + Metadata for the multimodal placeholder. It has 3 components: + - placeholder_map: + A mapping from modality to placeholder string. + Modality can be "image", "video", "audio", etc. + - placeholder_placement: + The placement of the placeholders, e.g. before or after the text prompt. + - placeholders_separator: + The separator between the placeholders, e.g. some models use "\n" to separate the placeholders. + """ + placeholder_map: Dict[str, str] = field(default_factory=dict) + placeholder_placement: MultimodalPlaceholderPlacement = MultimodalPlaceholderPlacement.AFTER_TEXT + placeholders_separator: str = "\n" + + +class MultimodalPlaceholderRegistry: + """ + Registry for the multimodal models to keep track of the placeholder information. + """ + + def __init__(self) -> None: + self._multimodal_placeholder_by_model_type: Dict[ + str, MultimodalPlaceholderMetadata] = {} + + def __str__(self) -> str: + s = "" + for model_type, placeholder_metadata in self._multimodal_placeholder_by_model_type.items( + ): + s += "-" * 100 + "\n" + s += f"Model type: {model_type}\n" + s += f"Placeholder map: {placeholder_metadata.placeholder_map}\n" + s += f"Placeholder placement: {placeholder_metadata.placeholder_placement}\n" + s += f"Placeholders separator: \"{placeholder_metadata.placeholders_separator}\"\n" + s += "-" * 80 + "\n" + return s + + def set_placeholder_metadata( + self, model_type: str, + placeholder_metadata: MultimodalPlaceholderMetadata): + self._multimodal_placeholder_by_model_type[ + model_type] = placeholder_metadata + + def remove_placeholder_metadata(self, model_type: str): + if model_type not in self._multimodal_placeholder_by_model_type: + raise ValueError(f"Model type '{model_type}' is not registered") + del self._multimodal_placeholder_by_model_type[model_type] + + def is_valid(self, model_type: str, modality: str) -> bool: + return model_type in self._multimodal_placeholder_by_model_type and \ + modality in self._multimodal_placeholder_by_model_type[model_type].placeholder_map + + def get_placeholder_metadata( + self, model_type: str) -> MultimodalPlaceholderMetadata: + if model_type not in self._multimodal_placeholder_by_model_type: + raise ValueError( + f"Model type {model_type} is not registered in MultimodalPlaceholderRegistry" + ) + return self._multimodal_placeholder_by_model_type[model_type] + + def get_placeholder(self, model_type: str, modality: str) -> str: + if not self.is_valid(model_type, modality): + raise ValueError( + f"Model type '{model_type}' with modality '{modality}' is not registered." + ) + return self._multimodal_placeholder_by_model_type[ + model_type].placeholder_map[modality] + + def get_placeholder_placement( + self, model_type: str) -> MultimodalPlaceholderPlacement: + if model_type not in self._multimodal_placeholder_by_model_type: + raise ValueError(f"Model type '{model_type}' is not registered") + return self._multimodal_placeholder_by_model_type[ + model_type].placeholder_placement + + def get_placeholders_separator(self, model_type: str) -> str: + if model_type not in self._multimodal_placeholder_by_model_type: + raise ValueError(f"Model type '{model_type}' is not registered") + return self._multimodal_placeholder_by_model_type[ + model_type].placeholders_separator + + def get_registered_image_model_types(self) -> Tuple[str, ...]: + return ( + model_type + for model_type in self._multimodal_placeholder_by_model_type + if "image" in self. + _multimodal_placeholder_by_model_type[model_type].placeholder_map) + + def get_registered_video_model_types(self) -> Tuple[str, ...]: + return ( + model_type + for model_type in self._multimodal_placeholder_by_model_type + if "video" in self. + _multimodal_placeholder_by_model_type[model_type].placeholder_map) + + def get_registered_audio_model_types(self) -> Tuple[str, ...]: + return ( + model_type + for model_type in self._multimodal_placeholder_by_model_type + if "audio" in self. + _multimodal_placeholder_by_model_type[model_type].placeholder_map) + + def get_registered_model_types(self) -> Tuple[str, ...]: + return tuple(self._multimodal_placeholder_by_model_type.keys()) + + +MULTIMODAL_PLACEHOLDER_REGISTRY = MultimodalPlaceholderRegistry() + + class InputProcessorRegistry: def __init__(self) -> None: @@ -89,9 +242,10 @@ class InputProcessorRegistry: INPUT_PROCESSOR_REGISTRY = InputProcessorRegistry() -def register_input_processor(processor_cls: Type[InputProcessor], - model_type: str, - out_of_tree: bool = False): +def register_input_processor( + processor_cls: Type[InputProcessor], + model_type: str, + placeholder_metadata: MultimodalPlaceholderMetadata = None): """ Register an input processor to a model class. NOTE: @@ -99,17 +253,18 @@ def register_input_processor(processor_cls: Type[InputProcessor], the model type only for that. 2. If this is used for other models in the future, this logic needs to be updated e.g. adding another version of this API without the model_type. - 3. If the model is not in the tree, user needs to set out_of_tree to True - to bypass the model type check and provide their own input preparation. """ def wrapper(model_cls: N) -> N: INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type[ model_cls] = processor_cls - if not out_of_tree: - assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, \ - f"Model type {model_type} not in {ALL_SUPPORTED_MULTIMODAL_MODELS}.\n" \ - "Please see the tensorrt_llm/inputs/utils.py file for more information." + if placeholder_metadata is None: + raise ValueError( + f"A valid placeholder_metadata must be provided but got {placeholder_metadata}" + ) + + MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata( + model_type, placeholder_metadata) return model_cls @@ -163,41 +318,88 @@ def create_input_processor_with_hash( A wrapped processor that modifies prompts before processing. """ + def multimodal_hashing_process( + inputs: TextPrompt, sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + """ + Process the multinmodal hashing for media tokens if possible. + """ + assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support." + mm_data = inputs['multi_modal_data'] + num_mm_tokens = find_mm_token_lengths(mm_data, input_processor) + if len(num_mm_tokens) > 0: + mm_hashes = apply_mm_hashes(mm_data, hash_lib) + prompt_token_ids, extra_processed_inputs = input_processor( + inputs, sampling_params) + start_positions = find_mm_token_positions( + input_ids=prompt_token_ids, # token sequence + num_mm_tokens= + num_mm_tokens, # list of lengths of each chunk of visual tokens + vocab_size=input_processor.model_config.vocab_size, + ) + # flatten the hashes from dict to a single list + mm_hashes = [h for hashes in mm_hashes.values() for h in hashes] + validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions, + num_mm_tokens) + mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes + ] # nested list w/ multiple int32 per hash + + extra_processed_inputs[ + "multimodal_input"] = MultimodalInput.from_components( + mm_hashes_int32, start_positions, num_mm_tokens) + return prompt_token_ids, extra_processed_inputs + return [], None + def input_processor_wrapper( inputs: TextPrompt, sampling_params: SamplingParams ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: - try: - assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support." - mm_data = inputs['multi_modal_data'] - num_mm_tokens = find_mm_token_lengths(mm_data, input_processor) - if len(num_mm_tokens) > 0: - mm_hashes = apply_mm_hashes(mm_data, hash_lib) - prompt_token_ids, extra_processed_inputs = input_processor( - inputs, sampling_params) - start_positions = find_mm_token_positions( - input_ids=prompt_token_ids, # token sequence - num_mm_tokens= - num_mm_tokens, # list of lengths of each chunk of visual tokens - vocab_size=input_processor.model_config.vocab_size, - ) - # flatten the hashes from dict to a single list - mm_hashes = [h for hashes in mm_hashes.values() for h in hashes] - validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions, - num_mm_tokens) - mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes - ] # nested list w/ multiple int32 per hash + try_multimodal_hashing = False # only used for first time + use_multimodal_hashing = False # used for subsequent calls + modalities = list(set(inputs['multi_modal_data'].keys()) + ) if 'multi_modal_data' in inputs else [] + if len(modalities) > 0: + # NOTE: tensorrt_llm/inputs/multimodal.py:find_mm_token_lengths only supports image data for now + if len(modalities) == 1 and modalities[0] == "image": + # only try multimodal hashing if the inputs only contain image data + if input_processor.multimodal_hashing_supported is not None: + use_multimodal_hashing = input_processor.multimodal_hashing_supported + else: + # we need to try the multimodal hashing for the first time to determine if it is supported + try_multimodal_hashing = True - extra_processed_inputs[ - "multimodal_input"] = MultimodalInput.from_components( - mm_hashes_int32, start_positions, num_mm_tokens) + if try_multimodal_hashing or use_multimodal_hashing: + try: + prompt_token_ids, extra_processed_inputs = multimodal_hashing_process( + inputs, sampling_params) + if try_multimodal_hashing: + # if trying for first time, set the flag to True + input_processor.multimodal_hashing_supported = True return prompt_token_ids, extra_processed_inputs - else: + except Exception as e: + import traceback + traceback.print_exc() + logger.warning(f"Multimodal hashing failed: {e}.") + if try_multimodal_hashing: + # if trying for first time, fall back to basic input processor + # and set the flag to False so that we don't try again + input_processor.multimodal_hashing_supported = False + logger.warning("Falling back to basic input processor.") + try: + return input_processor(inputs, sampling_params) + except Exception as e2: + import traceback + traceback.print_exc() + logger.warning(f"Basic input processor failed: {e}.") + raise e2 + else: + raise e + else: + try: return input_processor(inputs, sampling_params) - except Exception as e: - # Fall back to basic input processor if multimodal processing fails - logger.warning( - f"Multimodal hashing failed: {e}. Falling back to basic input processor." - ) - return input_processor(inputs, sampling_params) + except Exception as e: + import traceback + traceback.print_exc() + logger.warning(f"Basic input processor failed: {e}.") + raise e return input_processor_wrapper diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 7764436e25..3b856a2bfb 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -1,6 +1,5 @@ import asyncio import base64 -import enum import tempfile from collections import defaultdict from io import BytesIO @@ -18,16 +17,46 @@ from torchvision.transforms import ToTensor from transformers import AutoProcessor, ProcessorMixin from transformers.utils import logging +from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY, + MultimodalPlaceholderPlacement) from tensorrt_llm.llmapi.llm_utils import ModelLoader from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer logger = logging.get_logger(__name__) +def rgba_to_rgb( + image: Image.Image, + background_color: Union[tuple[int, int, int], list[int]] = (255, 255, 255) +) -> Image.Image: + """Convert an RGBA image to RGB with filled background color. + + Uses white (255, 255, 255) as the default background color because: + 1. It's the most neutral and commonly expected background for images + 2. Maintains backward compatibility with existing code + """ + if image.mode != "RGBA": + raise ValueError( + f"Expected image mode to be 'RGBA', but got '{image.mode}'") + converted = Image.new("RGB", image.size, background_color) + converted.paste(image, mask=image.split()[3]) # 3 is the alpha channel + return converted + + +def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image: + """Convert image to specified mode with proper handling of RGBA to RGB conversion.""" + if image.mode == to_mode: + return image + elif image.mode == "RGBA" and to_mode == "RGB": + return rgba_to_rgb(image) + else: + return image.convert(to_mode) + + def _load_and_convert_image(image): image = Image.open(image) image.load() - return image.convert("RGB") + return convert_image_mode(image, "RGB") def load_base64_image(parsed_url: str) -> Image.Image: @@ -209,60 +238,25 @@ NOTE: placeholder for the model needs to be added in retrieve_multimodal_placeholder(). """ -SUPPORTED_QWEN_MODEL_GROUP = ["qwen2_vl", "qwen2_5_vl"] -SUPPORTED_GEMMA_MODEL_GROUP = ["gemma3"] -SUPPORTED_LLAMA_MODEL_GROUP = ["mllama", "llama4"] -SUPPORTED_LLAVA_IMAGE_MODEL_GROUP = ["llava_llama", "llava_next"] -SUPPORTED_LLAVA_VIDEO_MODEL_GROUP = ["llava_llama"] -SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP = ["mistral3"] -SUPPORTED_HYPERCLOVAX_MODEL_GROUP = ["hyperclovax_vlm"] -SUPPORTED_PHI_MODEL_GROUP = ["phi4mm"] - -ALL_SUPPORTED_IMAGE_MODELS = SUPPORTED_QWEN_MODEL_GROUP \ - + SUPPORTED_LLAMA_MODEL_GROUP \ - + SUPPORTED_LLAVA_IMAGE_MODEL_GROUP \ - + SUPPORTED_HYPERCLOVAX_MODEL_GROUP \ - + SUPPORTED_GEMMA_MODEL_GROUP \ - + SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP \ - + SUPPORTED_PHI_MODEL_GROUP - -ALL_SUPPORTED_VIDEO_MODELS = SUPPORTED_QWEN_MODEL_GROUP \ - + SUPPORTED_LLAVA_VIDEO_MODEL_GROUP - -ALL_SUPPORTED_AUDIO_MODELS = SUPPORTED_PHI_MODEL_GROUP - -ALL_SUPPORTED_MULTIMODAL_MODELS = list(set(ALL_SUPPORTED_IMAGE_MODELS) \ - | set(ALL_SUPPORTED_VIDEO_MODELS) \ - | set(ALL_SUPPORTED_AUDIO_MODELS)) - HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama"] PLACEHOLDER_EXCEPTIONS = ["llava_next"] -class MultimodalPlaceholderPlacement(enum.Enum): - INVALID = -1 - BEFORE_TEXT = 0 - AFTER_TEXT = 1 +# Helpers to always get the latest supported multimodal model types from the registry +def ALL_SUPPORTED_MULTIMODAL_MODELS(): + return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types() -PLACEHOLDER_PLACEMENT_MAP = { - "qwen2_vl": MultimodalPlaceholderPlacement.BEFORE_TEXT, - "qwen2_5_vl": MultimodalPlaceholderPlacement.BEFORE_TEXT, - "llava_llama": MultimodalPlaceholderPlacement.BEFORE_TEXT, - "llava_next": MultimodalPlaceholderPlacement.BEFORE_TEXT, - "llama4": MultimodalPlaceholderPlacement.BEFORE_TEXT, - "mllama": MultimodalPlaceholderPlacement.BEFORE_TEXT, - "hyperclovax_vlm": MultimodalPlaceholderPlacement.AFTER_TEXT, - "gemma3": MultimodalPlaceholderPlacement.BEFORE_TEXT, - # NOTE: for mistral3 multimodal models, it does not strictly have to be before the text. - # Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/ - # src/mistral_common/tokens/tokenizers/base.py#L326 - # However, accuracy tests show that the model generates higher quality output when the image - # precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM). - "mistral3": MultimodalPlaceholderPlacement.BEFORE_TEXT, - "phi4mm": MultimodalPlaceholderPlacement.BEFORE_TEXT, -} -assert len(PLACEHOLDER_PLACEMENT_MAP) == len(ALL_SUPPORTED_MULTIMODAL_MODELS) +def ALL_SUPPORTED_IMAGE_MODELS(): + return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types() + + +def ALL_SUPPORTED_VIDEO_MODELS(): + return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types() + + +def ALL_SUPPORTED_AUDIO_MODELS(): + return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types() def retrieve_multimodal_placeholder(model_type: str, modality: str, @@ -276,41 +270,16 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str, current_count: The number of multimodal data already added. """ - - if modality == "image": - if model_type in SUPPORTED_QWEN_MODEL_GROUP: - return "<|vision_start|><|image_pad|><|vision_end|>" - elif model_type in SUPPORTED_LLAMA_MODEL_GROUP: - return "<|image|>" - elif model_type in SUPPORTED_LLAVA_IMAGE_MODEL_GROUP: - return "<image>" - elif model_type in SUPPORTED_GEMMA_MODEL_GROUP: - return "<start_of_image>" - elif model_type in SUPPORTED_HYPERCLOVAX_MODEL_GROUP: - return '<im_end>\n<|im_start|>user (mime) \n{"type": "image/jpeg", "filename": ""}<|im_end|>\n' + \ - '<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n' + \ - '<|im_start|>image/aux\n다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 keyword와 bbox 위치입니다.' + \ - 'bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 형태입니다. 참고하여 답변하세요. {"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}' - elif model_type in SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP: - # Ref: https://github.com/mistralai/mistral-common/blob/26a6bb3a07ee0b78a3808f2797f23e1d28514b93/ - # src/mistral_common/tokens/tokenizers/base.py#L60 - return "[IMG]" - elif model_type in SUPPORTED_PHI_MODEL_GROUP: - return f"<|image_{current_count}|>" - raise TypeError( - f"For image modality, only {ALL_SUPPORTED_IMAGE_MODELS} are supported but got {model_type}" - ) - elif modality == "video": - if model_type in SUPPORTED_QWEN_MODEL_GROUP: - return "<|vision_start|><|video_pad|><|vision_end|>" - elif model_type in SUPPORTED_LLAVA_VIDEO_MODEL_GROUP: - return "<vila/video>" - raise TypeError( - f"For video modality, only {ALL_SUPPORTED_VIDEO_MODELS} are supported but got {model_type}" - ) - elif modality == "audio": - if model_type in SUPPORTED_PHI_MODEL_GROUP: - return f"<|audio_{current_count}|>" + if MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(model_type, modality): + """ + The placeholder is a string with a single placeholder for the current count. + - For example, if the placeholder is "<|image_{0}|>", and the current count is 1, + the placeholder will be "<|image_1|>". + - However, if the placeholder is "<|image|>", the current count would be ignored. + In this case, the placeholder would be "<|image|>". + """ + return MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder( + model_type, modality).format(current_count) raise TypeError(f"Unknown modality: {modality}") @@ -379,17 +348,15 @@ def add_multimodal_placeholders(model_type: str, text_prompt: str, for placeholder in mm_placeholder_counts: placeholders.extend([placeholder] * mm_placeholder_counts[placeholder]) parts = [] - match PLACEHOLDER_PLACEMENT_MAP[model_type]: + match MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder_placement(model_type): case MultimodalPlaceholderPlacement.BEFORE_TEXT: parts.extend(placeholders) parts.append(text_prompt) case MultimodalPlaceholderPlacement.AFTER_TEXT: parts.append(text_prompt) parts.extend(placeholders) - if model_type == "phi4mm": - return "".join(parts) - else: - return "\n".join(parts) + return MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholders_separator( + model_type).join(parts) def resolve_hf_chat_template( diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py index ebfaa8fdea..f995b6390d 100755 --- a/tensorrt_llm/layers/attention.py +++ b/tensorrt_llm/layers/attention.py @@ -20,8 +20,8 @@ import tensorrt as trt import torch from .._common import default_net, precision -from .._utils import (fp32_array, get_sm_version, int32_array, is_same_dtype, - set_obj_attrs, trt_dtype_to_np, trt_dtype_to_str) +from .._utils import (fp32_array, int32_array, is_same_dtype, set_obj_attrs, + trt_dtype_to_np, trt_dtype_to_str) # isort: off from ..functional import ( @@ -1755,8 +1755,6 @@ class BertAttention(Module): if default_net().plugin_config.bert_attention_plugin: # TRT plugin mode assert input_lengths is not None - assert get_sm_version() < 100 or get_sm_version() >= 120, \ - "bert_attention_plugin does not support SM100" context = bert_attention( qkv, input_lengths, diff --git a/tensorrt_llm/layers/moe.py b/tensorrt_llm/layers/moe.py index 165e17f67d..3db89d52a6 100755 --- a/tensorrt_llm/layers/moe.py +++ b/tensorrt_llm/layers/moe.py @@ -40,7 +40,8 @@ from ..module import Module, ModuleList from ..parameter import Parameter from ..plugin import TRT_LLM_PLUGIN_NAMESPACE from ..quantization import GroupwiseQuantAlgo, QuantMode -from ..quantization.functional import (postprocess_weight_only, +from ..quantization.functional import (get_weight_scale_interleave_factor, + postprocess_weight_only, preprocess_weights_for_mixed_gemm, quantize) from .linear import RowLinear @@ -54,7 +55,8 @@ activation_str_to_int_map = { "silu": 2, "swiglu": 3, "geglu": 4, - "identity": 5, + "swiglu_bias": 5, + "identity": 6, } @@ -488,11 +490,18 @@ class MOEWeightWrapper(Module): self.alpha = Parameter(shape=(experts_per_node, ), dtype=trt.float32) elif quant_mode.has_per_group_scaling(): - self.weight = Parameter(shape=(experts_per_node, in_features, - out_features // 4), - dtype=dtype) - scale_shape = (experts_per_node, in_features // group_size, - out_features) + self.weight = Parameter( + shape=(experts_per_node, in_features, + out_features // 4), # int4 <--> fp16/bf16 + dtype=dtype) + if groupwise_quant_algo & GroupwiseQuantAlgo.W4A8_ALPHA: + scale_interleave_factor = get_weight_scale_interleave_factor( + in_features, group_size) + else: + scale_interleave_factor = 1 + scale_shape = (experts_per_node, + in_features // group_size // scale_interleave_factor, + out_features * scale_interleave_factor) self.weights_scaling_factor = Parameter(shape=scale_shape, dtype=dtype) if groupwise_quant_algo & GroupwiseQuantAlgo.ZERO: @@ -692,7 +701,7 @@ class MOEWeightWrapper(Module): weights = stack_weights(tllm_key, weights) if tllm_key.endswith("weights_block_scaling_factor_interleaved"): weights = stack_weights(tllm_key, weights) - weights = torch.ops.trtllm.nvfp4_block_scale_interleave( + weights = torch.ops.trtllm.block_scale_interleave( weights.to(torch.float8_e4m3fn).view( torch.uint8).cpu().contiguous()).reshape( weights.shape).view(torch.float8_e4m3fn) @@ -768,7 +777,7 @@ class MixtureOfExperts(Module): self.use_int8_weight = use_int8_weight self.group_size = group_size - if self.use_int8_weight: + if self.use_int8_weight and self.group_size > 0: raise NotImplementedError("INT8-GPTQ is not implemented for MoE.") self.static_routing = static_routing diff --git a/tensorrt_llm/llmapi/build_cache.py b/tensorrt_llm/llmapi/build_cache.py index 86c9eb4e77..6b61d27773 100644 --- a/tensorrt_llm/llmapi/build_cache.py +++ b/tensorrt_llm/llmapi/build_cache.py @@ -12,7 +12,7 @@ from typing import Any, List, Optional import filelock import tensorrt_llm -from tensorrt_llm import BuildConfig +from tensorrt_llm.builder import BuildConfig from tensorrt_llm.llmapi.utils import enable_llm_debug, print_colored from tensorrt_llm.logger import logger diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 321ec11bd7..02298b1743 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -32,8 +32,8 @@ from ..logger import logger from ..sampling_params import SamplingParams from ..scheduling_params import SchedulingParams from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING, - TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig, - PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs) + TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig, + PybindMirror, TorchLlmArgs, TrtLlmArgs) from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig, LlmBuildStats, ModelLoader, _ModelRuntimeContext) from .mpi_session import MpiPoolSession, external_mpi_comm_available @@ -403,13 +403,12 @@ class BaseLLM: 'multimodal_input'), multimodal_data=extra_processed_inputs.get( 'multimodal_data')) - # Convert to shared tensor handle to reduce IPC overhead - # for values with non-selected keys, it's no-op - multimodal_params.to_handle("multimodal_data", - key="multimodal_embedding") # Only pass it if it has content if not multimodal_params.has_content(): multimodal_params = None + else: + # Convert to shared tensor handle to reduce IPC overhead + multimodal_params.to_handle("multimodal_data") else: raise TypeError( f"The inputs must be type str or list of int, but got {type(inputs)}" @@ -548,7 +547,7 @@ class BaseLLM: if sampling_params._stream_interval is None: sampling_params._stream_interval = getattr(self.args, "stream_interval", 1) - + sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics return sampling_params def _check_arguments(self, prompt_len: int, query_len: int, @@ -1015,32 +1014,10 @@ class _TorchLLM(BaseLLM): spec_config = self.args.speculative_config max_batch_size = self._executor_config.max_batch_size - # Apply default heuristic to AutoDecodingConfig based on benchmark results - # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3 - # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5 - # With concurrency > 32, speculative decoding is disabled. + if spec_config is not None and spec_config.decoding_type == "AUTO": - if not self.args.disable_overlap_scheduler: - logger.info( - "Disable overlap scheduler to enable Auto speculative decoding with Ngram." - ) - # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32. - # Therefore, we disable overlap scheduler to enable NGram speculative decoding. - self.args.disable_overlap_scheduler = True - - spec_config = NGramDecodingConfig( - max_draft_len=5 if max_batch_size <= 4 else 3, - max_matching_ngram_size=3 if max_batch_size <= 4 else 5, - is_keep_all=True, - is_use_oldest=True, - is_public_pool=True, - # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic. - is_auto_heuristic=True, - ) - - logger.info( - f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}" - ) + from tensorrt_llm._torch.speculative import suggest_spec_config + spec_config = suggest_spec_config(max_batch_size) update_executor_config( self._executor_config, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index b7d46ed6fa..948c4b1688 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -19,8 +19,8 @@ from pydantic import PrivateAttr, field_validator, model_validator from strenum import StrEnum from transformers import PreTrainedTokenizerBase -from tensorrt_llm.lora_manager import (LoraConfig, - get_default_trtllm_modules_to_hf_modules) +from tensorrt_llm.lora_helper import (LoraConfig, + get_default_trtllm_modules_to_hf_modules) from .._utils import mpi_rank from ..auto_parallel import AutoParallelConfig, infer_cluster_config @@ -149,7 +149,7 @@ class CudaGraphConfig(StrictBaseModel): # Add powers of 2 up to max_batch_size batch_sizes += [ - 2**i for i in range(8, math.floor(math.log(max_batch_size, 2))) + 2**i for i in range(8, math.ceil(math.log(max_batch_size, 2))) ] # Filter and sort batch sizes @@ -168,8 +168,9 @@ class MoeConfig(StrictBaseModel): Configuration for MoE. """ backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", - "VANILLA"] = Field(default='CUTLASS', - description="MoE backend to use.") + "VANILLA", + "TRITON"] = Field(default='CUTLASS', + description="MoE backend to use.") max_num_tokens: Optional[int] = Field( default=None, @@ -182,6 +183,12 @@ class MoeConfig(StrictBaseModel): description="Configuration for MoE load balancing.", json_schema_extra={"type": "Union[MoeLoadBalancerConfig, str]"}) + disable_finalize_fusion: bool = Field( + default=False, + description= + "Disable FC2+finalize kernel fusion in CUTLASS MoE backend. Setting this to True recovers deterministic numerical behavior with top-k > 2." + ) + @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -341,6 +348,11 @@ class DecodingBaseConfig(StrictBaseModel): max_draft_len: Optional[int] = None speculative_model_dir: Optional[Union[str, Path]] = None + # PyTorch only. + # When specified, speculation will be disabled at batch sizes above + # this value. Otherwise, speculation will always be on. + max_concurrency: Optional[int] = None + @classmethod def from_dict(cls, data: dict): # dispatch to the correct decoding config @@ -468,9 +480,6 @@ class NGramDecodingConfig(DecodingBaseConfig): is_keep_all: bool = True is_use_oldest: bool = True is_public_pool: bool = True - # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic. - # User should not set this flag. Use AutoDecodingConfig instead. - is_auto_heuristic: bool = False @classmethod def from_dict(cls, data: dict): @@ -534,13 +543,10 @@ class AutoDecodingConfig(DecodingBaseConfig): """ Configuration for auto speculative decoding. - This config is used to automatically select the best speculative decoding algorithm. + This config will automatically select a good, draft-model free + speculation algorithm with some heuristic. - According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32. - Default heuristic: - With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3 - With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5 - With concurrency > 32, speculative decoding is disabled. + Attributes that are inherited from the base class are ignored. """ @classmethod @@ -969,6 +975,11 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): description= "Maximum size of the event buffer. If set to 0, the event buffer will not be used." ) + attention_dp_events_gather_period_ms: int = Field( + default=5, + description= + "The period in milliseconds to gather attention DP events across ranks." + ) enable_partial_reuse: bool = Field( default=True, description= @@ -985,6 +996,14 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): dtype: str = Field(default="auto", description="The data type to use for the KV cache.") + # This is a pure python field, not a pybind field. It is only for the Pytorch backend. + mamba_ssm_cache_dtype: Literal[ + "auto", "float16", "bfloat16", "float32"] = Field( + default="auto", + description= + "The data type to use for the Mamba SSM cache. If set to 'auto', the data type will be inferred from the model config." + ) + def _to_pybind(self): return _KvCacheConfig( enable_block_reuse=self.enable_block_reuse, @@ -999,7 +1018,10 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): event_buffer_max_size=self.event_buffer_max_size, enable_partial_reuse=self.enable_partial_reuse, copy_on_partial_reuse=self.copy_on_partial_reuse, - use_uvm=self.use_uvm) + use_uvm=self.use_uvm, + attention_dp_events_gather_period_ms=self. + attention_dp_events_gather_period_ms, + ) @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig) @@ -1039,7 +1061,7 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror): Configuration for the cache transceiver. """ - backend: Optional[Literal["default", "ucx", "nixl", "mpi"]] = Field( + backend: Optional[Literal["DEFAULT", "UCX", "NIXL", "MPI"]] = Field( default=None, description= "The communication backend type to use for the cache transceiver.") @@ -1303,6 +1325,10 @@ class BaseLlmArgs(StrictBaseModel): status="deprecated", ) + return_perf_metrics: bool = Field(default=False, + description="Return perf metrics.", + status="prototype") + _parallel_config: Optional[object] = PrivateAttr(default=None) _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None) _speculative_model: Optional[str] = PrivateAttr(default=None) @@ -1942,6 +1968,13 @@ class LoadFormat(Enum): DUMMY = 1 +class SamplerType(StrEnum): + """Enum for sampler type options.""" + TRTLLMSampler = "TRTLLMSampler" + TorchSampler = "TorchSampler" + auto = "auto" + + class TorchCompileConfig(StrictBaseModel): """ Configuration for torch.compile. @@ -1957,6 +1990,21 @@ class TorchCompileConfig(StrictBaseModel): default=False, description="Enable piecewise CUDA graph in torch.compile.") + capture_num_tokens: Optional[List[int]] = Field( + default=None, + description= + "List of num of tokens to capture the piecewise CUDA graph for. If not provided, the number of tokens will be the same as cuda_graph_config.batch_sizes." + ) + + @field_validator('capture_num_tokens') + @classmethod + def validate_capture_num_tokens(cls, v): + if v is None: + return v + if any(t <= 0 for t in v): + raise ValueError("capture_num_tokens must contain positive ints.") + return sorted(set(v), reverse=True) + enable_userbuffers: bool = Field( default=True, description= @@ -2029,10 +2077,10 @@ class TorchLlmArgs(BaseLlmArgs): "If true, will iterate over sampling_params of each request and use the corresponding sampling strategy, e.g. top-k, top-p, etc.", status="beta") - enable_trtllm_sampler: bool = Field( - default=False, + sampler_type: Union[str, SamplerType] = Field( + default=SamplerType.auto, description= - "If true, will use the TRTLLM sampler instead of the PyTorch sampler. The TRTLLM sampler has a wide coverage of sampling strategies.", + "The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. Defaults to auto, which will use TorchSampler unless BeamSearch is requested.", status="prototype") enable_iter_perf_stats: bool = Field( @@ -2090,14 +2138,12 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) - allreduce_strategy: Optional[ - Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', - 'LOWPRECISION', 'MNNVL']] = Field( - default='AUTO', - description="Allreduce strategy to use.", - status="beta", - ) - + allreduce_strategy: Optional[Literal[ + 'AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', + 'LOWPRECISION', 'MNNVL', + 'NCCL_SYMMETRIC']] = Field(default='AUTO', + description="Allreduce strategy to use.", + status="beta") checkpoint_loader: Optional[object] = Field( default=None, description="The checkpoint loader to use for this LLM instance.", @@ -2320,8 +2366,9 @@ class TorchLlmArgs(BaseLlmArgs): attn_backend=self.attn_backend, moe_backend=self.moe_config.backend, enable_mixed_sampler=self.enable_mixed_sampler, - enable_trtllm_sampler=self.enable_trtllm_sampler, + sampler_type=self.sampler_type, kv_cache_dtype=self.kv_cache_config.dtype, + mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype, enable_iter_perf_stats=self.enable_iter_perf_stats, enable_iter_req_stats=self.enable_iter_req_stats, print_iter_log=self.print_iter_log, @@ -2336,6 +2383,10 @@ class TorchLlmArgs(BaseLlmArgs): enable_piecewise_cuda_graph if self.torch_compile_config is not None else TorchCompileConfig. model_fields['enable_piecewise_cuda_graph'].default, + torch_compile_piecewise_cuda_graph_num_tokens=self. + torch_compile_config.capture_num_tokens + if self.torch_compile_config is not None else + TorchCompileConfig.model_fields['capture_num_tokens'].default, torch_compile_enable_userbuffers=self.torch_compile_config. enable_userbuffers if self.torch_compile_config is not None else TorchCompileConfig.model_fields['enable_userbuffers'].default, @@ -2346,6 +2397,7 @@ class TorchLlmArgs(BaseLlmArgs): enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker, load_format=self.load_format, enable_min_latency=self.enable_min_latency, + moe_disable_finalize_fusion=self.moe_config.disable_finalize_fusion, stream_interval=self.stream_interval, force_dynamic_quantization=self.force_dynamic_quantization, allreduce_strategy=self.allreduce_strategy, diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index a62568a54e..b2145ac793 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -426,6 +426,15 @@ class ModelLoader: "weight_block_size"): quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES quant_config.exclude_modules = ["*eh_proj"] + elif hf_quant_config.get("quant_method") == "mxfp4": + from .._torch.model_config import ModelConfig + quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( + self.llm_args.moe_config.backend) + quant_config.group_size = 32 + quant_config.exclude_modules = [ + 'block.*.attn.out', 'block.*.mlp.gate', + 'block.*.attn.qkv', 'embedding', 'unembedding' + ] else: raise NotImplementedError( f"Unsupported quantization_config: {hf_quant_config}.") @@ -568,12 +577,12 @@ class ModelLoader: self._engine = Engine.from_dir(self._model_dir) @staticmethod - def load_hf_tokenizer( - model_dir, - trust_remote_code: bool = True, - use_fast: bool = True) -> Optional[TransformersTokenizer]: + def load_hf_tokenizer(model_dir, + trust_remote_code: bool = True, + use_fast: bool = True, + **kwargs) -> Optional[TransformersTokenizer]: if (tokenizer := load_hf_tokenizer(model_dir, trust_remote_code, - use_fast)) is not None: + use_fast, **kwargs)) is not None: return tokenizer else: logger.warning(f"Failed to load tokenizer from {model_dir}") diff --git a/tensorrt_llm/llmapi/tokenizer.py b/tensorrt_llm/llmapi/tokenizer.py index 6e5f7bbcee..7e13643fb8 100644 --- a/tensorrt_llm/llmapi/tokenizer.py +++ b/tensorrt_llm/llmapi/tokenizer.py @@ -57,6 +57,11 @@ class TransformersTokenizer(TokenizerBase): def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict: return self.tokenizer.batch_encode_plus(texts, *args, **kwargs) + def get_chat_template(self, + chat_template: Optional[str] = None, + tools: Optional[List[Dict]] = None) -> str: + return self.tokenizer.get_chat_template(chat_template, tools) + def apply_chat_template( self, conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], *args, @@ -330,7 +335,8 @@ def _llguidance_tokenizer_info(tokenizer): def load_hf_tokenizer(model_dir: str, trust_remote_code: bool = True, - use_fast: bool = True) -> Optional[TransformersTokenizer]: + use_fast: bool = True, + **kwargs) -> Optional[TransformersTokenizer]: ''' Load a tokenizer from a Hugging Face model directory. Args: @@ -349,7 +355,11 @@ def load_hf_tokenizer(model_dir: str, padding_side='left', truncation_side='left', trust_remote_code=trust_remote_code, - use_fast=use_fast) + use_fast=use_fast, + **kwargs) - except Exception: + except Exception as e: + logger.warning( + f"Failed to load hf tokenizer from {model_dir}, encounter error: {e}" + ) return None diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py index 8b2e516dba..6500084190 100644 --- a/tensorrt_llm/llmapi/utils.py +++ b/tensorrt_llm/llmapi/utils.py @@ -3,6 +3,7 @@ import collections import hashlib import io import os +import re import sys import tempfile import threading @@ -508,8 +509,10 @@ def generate_api_docs_as_docstring(model: Type[BaseModel], type_str = str(type_hints[field_name]) type_str = type_str.replace("typing.", "") # Extract just the class name from full class path - if "<class '" in type_str: - type_str = type_str[8:-2] + for regex in [r"<class '([^']+)'>", r"<enum '([^']+)'>"]: + if (match := re.match(regex, type_str)) is not None: + type_str = match.group(1) + break else: type_str = field_type or 'Any' diff --git a/tensorrt_llm/lora_helper.py b/tensorrt_llm/lora_helper.py new file mode 100644 index 0000000000..37f5d534f7 --- /dev/null +++ b/tensorrt_llm/lora_helper.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +from ._utils import DictConversion + + +def get_missing_qkv_modules_from_lora_modules( + lora_target_modules: List[str]) -> List[str]: + """Get missing QKV modules from LoRA target modules. + + In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or + all disabled at the same time. However, some lora checkpoints (e.g. BART) only contain two of them, + so we use zero tensor to fill the missing ones. + """ + missing_qkv_modules = [] + if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]): + for lora_module in ["attn_q", "attn_k", "attn_v"]: + if lora_module not in lora_target_modules: + missing_qkv_modules.append(lora_module) + if any(x in lora_target_modules + for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]): + for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]: + if lora_module not in lora_target_modules: + missing_qkv_modules.append(lora_module) + return missing_qkv_modules + + +def get_default_trtllm_modules_to_hf_modules(): + """Get default mapping from TensorRT-LLM module names to HuggingFace module names.""" + return { + "attn_q": "q_proj", + "attn_k": "k_proj", + "attn_v": "v_proj", + "attn_dense": "o_proj", + "mlp_h_to_4h": "gate_proj", + "mlp_4h_to_h": "down_proj", + "mlp_gate": "up_proj", + "mlp_gate_up": "gate_up_proj", + "moe_h_to_4h": "w1", + "moe_4h_to_h": "w2", + "moe_gate": "w3", + "moe_router": "gate", + } + + +def use_lora( + model, + lora_config: "LoraConfig", + trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None, +): + """Use LoRA with the given model and configuration. + + This function is a wrapper that delegates to the appropriate loading function + based on the LoRA checkpoint source. + """ + if lora_config.lora_ckpt_source == "nemo": + from .lora_manager import load_nemo_lora + load_nemo_lora(model, lora_config) + elif lora_config.lora_ckpt_source == "hf": + from .lora_manager import load_hf_lora + load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules) + else: + raise ValueError( + f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}") + + +@dataclass +class LoraConfig(DictConversion): + lora_dir: List[str] = field(default_factory=list) + lora_ckpt_source: str = "hf" + max_lora_rank: int = 64 + lora_target_modules: List[str] = field(default_factory=list) + trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict) + max_loras: Optional[int] = None + max_cpu_loras: Optional[int] = None + + def __post_init__(self): + assert self.lora_ckpt_source in [ + "hf", "nemo" + ], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}" + ) + + @property + def missing_qkv_modules(self) -> List[str]: + return get_missing_qkv_modules_from_lora_modules( + self.lora_target_modules) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index f2e3204716..7440715474 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -5,7 +5,7 @@ import re import tarfile import warnings from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union @@ -16,8 +16,13 @@ import yaml from tensorrt_llm.bindings import internal as tb_internal -from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy +from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy from .layers.linear import ColumnLinear +from .lora_helper import ( + LoraConfig, + get_default_trtllm_modules_to_hf_modules, + get_missing_qkv_modules_from_lora_modules, +) from .mapping import Mapping from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp @@ -232,26 +237,6 @@ def norm_dora_magnitude( return norm_m -@dataclass -class LoraConfig(DictConversion): - lora_dir: List[str] = field(default_factory=list) - lora_ckpt_source: str = "hf" - max_lora_rank: int = 64 - lora_target_modules: List[str] = field(default_factory=list) - trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict) - max_loras: int | None = None - max_cpu_loras: int | None = None - - def __post_init__(self): - assert self.lora_ckpt_source in ["hf", "nemo"], ( - f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}" - ) - - @property - def missing_qkv_modules(self) -> List[str]: - return LoraManager.get_missing_qkv_modules(self.lora_target_modules) - - @dataclass class LoraModelConfig: lora_target_modules: list[str] @@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig): lora_config.lora_target_modules = lora_loader.lora_target_modules -def get_default_trtllm_modules_to_hf_modules(): - return { - "attn_q": "q_proj", - "attn_k": "k_proj", - "attn_v": "v_proj", - "attn_dense": "o_proj", - "mlp_h_to_4h": "gate_proj", - "mlp_4h_to_h": "down_proj", - "mlp_gate": "up_proj", - "mlp_gate_up": "gate_up_proj", - "moe_h_to_4h": "w1", - "moe_4h_to_h": "w2", - "moe_gate": "w3", - "moe_router": "gate", - } - - def load_torch_hf_lora(lora_config: LoraConfig): """This is a shortned version of load_hf_lora that is used for torch models. @@ -628,19 +596,6 @@ def load_hf_lora( ).to(torch_dtype) -def use_lora( - model, - lora_config: LoraConfig, - trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None, -): - if lora_config.lora_ckpt_source == "nemo": - load_nemo_lora(model, lora_config) - elif lora_config.lora_ckpt_source == "hf": - load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules) - else: - raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}") - - def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]: """Unpack model config and weights from a NeMo .nemo archive file. @@ -762,21 +717,8 @@ class LoraManager(object): ) @staticmethod - def get_missing_qkv_modules(lora_target_modules): - # In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or - # all disabled at the same time. - # However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor - # to fill the missing ones. - missing_qkv_modules = [] - if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]): - for lora_module in ["attn_q", "attn_k", "attn_v"]: - if lora_module not in lora_target_modules: - missing_qkv_modules.append(lora_module) - if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]): - for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]: - if lora_module not in lora_target_modules: - missing_qkv_modules.append(lora_module) - return missing_qkv_modules + def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]: + return get_missing_qkv_modules_from_lora_modules(lora_target_modules) @property def missing_qkv_modules(self) -> List[str]: diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py index f78fe093f7..cfc997b786 100644 --- a/tensorrt_llm/mapping.py +++ b/tensorrt_llm/mapping.py @@ -12,11 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from enum import IntEnum from typing import List import torch +class CpType(IntEnum): + # CP type for ulysses parallelism + ULYSSES = 0 + # CP type for star attention + STAR = 1 + # CP type for ring attention + RING = 2 + # CP type for helix parallelism + HELIX = 3 + + class Mapping(object): ''' A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2 @@ -135,58 +147,70 @@ class Mapping(object): if moe_cluster_size == -1: moe_cluster_size = 1 + cp_type = CpType.ULYSSES if cp_config is None else cp_config.get( + "cp_type", CpType.ULYSSES) + moe_world_size = tp_size if cp_type == CpType.ULYSSES else tp_size * cp_size + if moe_tp_size == -1 and moe_ep_size == -1: - moe_tp_size = tp_size // moe_cluster_size + moe_tp_size = moe_world_size // moe_cluster_size moe_ep_size = 1 elif moe_tp_size == -1: - moe_tp_size = tp_size // (moe_ep_size * moe_cluster_size) + moe_tp_size = moe_world_size // (moe_ep_size * moe_cluster_size) elif moe_ep_size == -1: - moe_ep_size = tp_size // (moe_tp_size * moe_cluster_size) + moe_ep_size = moe_world_size // (moe_tp_size * moe_cluster_size) if attn_tp_size == -1 and attn_cp_size == -1: - # fallback to ulysses - attn_tp_size = tp_size * cp_size - attn_cp_size = 1 + if cp_type == CpType.ULYSSES: + # fallback to ulysses + attn_tp_size = tp_size * cp_size + attn_cp_size = 1 + else: + # fallback to helix + attn_tp_size = tp_size + attn_cp_size = cp_size elif attn_tp_size == -1: - attn_tp_size = cp_size * tp_size // attn_cp_size + attn_tp_size = (tp_size * cp_size) // attn_cp_size elif attn_cp_size == -1: - attn_cp_size = cp_size * tp_size // attn_tp_size + attn_cp_size = (tp_size * cp_size) // attn_tp_size - if attn_cp_size != 1: + if attn_cp_size != 1 and cp_type == CpType.ULYSSES: raise ValueError( - f"attn_cp_size must be 1 for now, but got {attn_tp_size}, {attn_cp_size}." + f"attn_cp_size must be 1 for now for ulysses, but got {attn_tp_size}, {attn_cp_size}." ) if auto_parallel: - if tp_size != 1 or pp_size != 1 or tp_size != 1: + if tp_size != 1 or pp_size != 1 or cp_size != 1: raise ValueError( - f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}." - ) + "When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, " + f"but got {tp_size}, {pp_size}, {cp_size}.") else: if tp_size * pp_size * cp_size != world_size: raise ValueError( - f"world_size must equal to tp_size * pp_size * cp_size, but got {world_size} != {tp_size} * {pp_size} * {cp_size}." + "world_size must equal to tp_size * pp_size * cp_size, " + f"but got {world_size} != {tp_size} * {pp_size} * {cp_size}." ) moe_tp_ep_size = moe_tp_size * moe_ep_size moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size - if moe_tp_cluster_ep_size != tp_size: + if moe_tp_cluster_ep_size != moe_world_size: raise ValueError( - f"tp_size must equal to moe_tp_size * moe_ep_size * moe_cluster_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size} * {moe_cluster_size}" - ) + "moe_tp_size * moe_ep_size * moe_cluster_size must equal to moe_world_size, " + f"but got {moe_tp_cluster_ep_size} != {moe_world_size}") attn_tp_cp_size = attn_tp_size * attn_cp_size if attn_tp_cp_size != tp_size * cp_size: raise ValueError( - f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}" + "tp_size * cp_size must equal to attn_tp_size * attn_cp_size, " + f"but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}" ) - if moe_ep_size != 1 and cp_size > 1: - raise NotImplementedError("CP don't support MoE tp/ep yet") + if moe_ep_size != 1 and cp_size > 1 and cp_type != CpType.HELIX: + raise NotImplementedError( + f"CP {cp_type} doesn't support MoE tp/ep yet") self.tp_size = tp_size self.cp_size = cp_size @@ -275,6 +299,7 @@ class Mapping(object): and self.moe_ep_size == other.moe_ep_size and self.attn_tp_size == other.attn_tp_size and self.attn_cp_size == other.attn_cp_size + and self.cp_config == other.cp_config and self.auto_parallel == other.auto_parallel) def __hash__(self): @@ -290,6 +315,8 @@ class Mapping(object): self.moe_ep_size, self.attn_tp_size, self.attn_cp_size, + # note: we do not allow updating cp_config after initialization + tuple(sorted(self.cp_config.items())), self.auto_parallel, )) @@ -372,8 +399,17 @@ class Mapping(object): def local_rank(self): return self.rank % self.gpus_per_node - def has_cp(self): - return self.cp_size > 1 + @property + def dp_size(self): + return self.tp_size if self.enable_attention_dp else 1 + + def has_cp_ulysses(self): + return self.cp_size > 1 and self.cp_config.get( + "cp_type") == CpType.ULYSSES + + def has_cp_helix(self): + return self.cp_size > 1 and self.cp_config.get( + "cp_type") == CpType.HELIX def get_node_rank(self, rank: int): return rank // self.gpus_per_node @@ -411,6 +447,29 @@ class Mapping(object): p = p - self.world_size return p + def is_last_cp_rank(self): + return self.cp_rank == self.cp_size - 1 + + def is_first_cp_rank(self): + return self.cp_rank == 0 + + def has_cp(self): + return self.cp_size > 1 + + def prev_cp_rank(self): + p = self.rank - self.tp_size + if p // (self.tp_size * self.cp_size) < self.rank // (self.tp_size * + self.cp_size): + return p + self.tp_size * self.cp_size + return p + + def next_cp_rank(self): + p = self.rank + self.tp_size + if p // (self.tp_size * self.cp_size) > self.rank // (self.tp_size * + self.cp_size): + return p - self.tp_size * self.cp_size + return p + def has_moe_cluster(self): return self.moe_cluster_size > 1 @@ -449,5 +508,6 @@ class Mapping(object): 'moe_ep_size': self.moe_ep_size, 'attn_tp_size': self.attn_tp_size, 'attn_cp_size': self.attn_cp_size, + 'cp_config': self.cp_config, 'auto_parallel': self.auto_parallel, } diff --git a/tensorrt_llm/metrics/__init__.py b/tensorrt_llm/metrics/__init__.py new file mode 100644 index 0000000000..f68d9f698a --- /dev/null +++ b/tensorrt_llm/metrics/__init__.py @@ -0,0 +1,4 @@ +from .collector import * +from .enums import * + +__all__ = ["MetricsCollector", "MetricNames", "RequestEventTiming"] diff --git a/tensorrt_llm/metrics/collector.py b/tensorrt_llm/metrics/collector.py new file mode 100644 index 0000000000..952529393c --- /dev/null +++ b/tensorrt_llm/metrics/collector.py @@ -0,0 +1,105 @@ +"""Utilities for Prometheus Metrics Collection.""" + +import time +from typing import Dict, Optional, Union + +from .enums import MetricNames + + +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0rc1/vllm/engine/metrics.py#L30 +class MetricsCollector: + labelname_finish_reason = "finished_reason" + + def __init__(self, labels: Dict[str, str]) -> None: + from prometheus_client import Counter, Histogram + self.last_log_time = time.time() + self.labels = labels + + self.finish_reason_label = { + MetricsCollector.labelname_finish_reason: "unknown" + } + self.labels_with_finished_reason = { + **self.labels, + **self.finish_reason_label + } + + self.counter_request_success = Counter( + name="request_success_total", + documentation="Count of successfully processed requests.", + labelnames=self.labels_with_finished_reason.keys()) + + self.histogram_e2e_time_request = Histogram( + name="e2e_request_latency_seconds", + documentation="Histogram of end to end request latency in seconds.", + buckets=[ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, + 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + ], + labelnames=self.labels.keys()) + + self.histogram_time_to_first_token = Histogram( + name="time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + buckets=[ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, + 2560.0 + ], + labelnames=self.labels.keys()) + + self.histogram_time_per_output_token = Histogram( + name="time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ], + labelnames=self.labels.keys()) + + self.histogram_queue_time_request = Histogram( + name="request_queue_time_seconds", + documentation= + "Histogram of time spent in WAITING phase for request.", + buckets=[ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, + 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + ], + labelnames=self.labels.keys()) + + def _label_merge(self, labels: Dict[str, str]) -> Dict[str, str]: + if labels is None or len(labels) == 0: + return self.labels + return {**self.labels, **labels} + + def _log_counter(self, counter, labels: Dict[str, str], + data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self._label_merge(labels)).inc(data) + + def _log_histogram(self, histogram, data: Union[int, float]) -> None: + # Convenience function for logging to histogram. + histogram.labels(**self.labels).observe(data) + + def log_request_success(self, data: Union[int, float], + labels: Dict[str, str]) -> None: + self._log_counter(self.counter_request_success, labels, data) + self.last_log_time = time.time() + + def log_histogram(self, data: Optional[dict[str, float]]) -> None: + if e2e := data.get(MetricNames.E2E, 0): + self._log_histogram(self.histogram_e2e_time_request, e2e) + if ttft := data.get(MetricNames.TTFT, 0): + self._log_histogram(self.histogram_time_to_first_token, ttft) + if tpot := data.get(MetricNames.TPOT, 0): + self._log_histogram(self.histogram_time_per_output_token, tpot) + if request_queue_time := data.get(MetricNames.REQUEST_QUEUE_TIME, 0): + self._log_histogram(self.histogram_queue_time_request, + request_queue_time) + self.last_log_time = time.time() + + def log_metrics_dict(self, metrics_dict: dict[str, float]) -> None: + if finish_reason := metrics_dict.get( + MetricsCollector.labelname_finish_reason): + self.log_request_success( + 1, {MetricsCollector.labelname_finish_reason: finish_reason}) + self.log_histogram(metrics_dict) diff --git a/tensorrt_llm/metrics/enums.py b/tensorrt_llm/metrics/enums.py new file mode 100644 index 0000000000..5ce982281b --- /dev/null +++ b/tensorrt_llm/metrics/enums.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class MetricNames(Enum): + TTFT = "ttft" + TPOT = "tpot" + E2E = "e2e" + REQUEST_QUEUE_TIME = "request_queue_time" + + +class RequestEventTiming(Enum): + ARRIVAL_TIME = "arrival_time" + FIRST_TOKEN_TIME = "first_token_time" # nosec: B105 + FIRST_SCHEDULED_TIME = "first_scheduled_time" + LAST_TOKEN_TIME = "last_token_time" # nosec: B105 diff --git a/tensorrt_llm/models/enc_dec/model.py b/tensorrt_llm/models/enc_dec/model.py index 338f16c54a..be3c5afc49 100644 --- a/tensorrt_llm/models/enc_dec/model.py +++ b/tensorrt_llm/models/enc_dec/model.py @@ -20,7 +20,8 @@ import tensorrt as trt import torch from tensorrt_llm._common import default_net -from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch +from tensorrt_llm._utils import (numpy_to_torch, pad_vocab_size, + str_dtype_to_torch) from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType, MLPType, PositionEmbeddingType, Tensor, assertion, cast, gather_last_token_logits, @@ -35,9 +36,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams, LanguageAdapterConfig, LayerNorm, LoraParams, PromptTuningEmbedding, RmsNorm) # yapf: enable -from tensorrt_llm.lora_manager import (LoraConfig, - get_default_trtllm_modules_to_hf_modules, - use_lora) +from tensorrt_llm.lora_helper import (LoraConfig, + get_default_trtllm_modules_to_hf_modules, + use_lora) from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel from tensorrt_llm.module import Module, ModuleList @@ -1156,9 +1157,11 @@ class DecoderModel(PretrainedModel): self.transformer.assign_module(decoder_layers, "layers") if self.mapping.is_last_pp_rank(): + vocab_size_padded = pad_vocab_size(self.config.vocab_size, + self.config.mapping.tp_size) self.lm_head = ColumnLinear( self.config.hidden_size, - self.config.vocab_size, + vocab_size_padded, bias=False if not hasattr(self.config, "has_lm_head_bias") else self.config.has_lm_head_bias, dtype=self.config.dtype, @@ -1208,7 +1211,6 @@ class DecoderModel(PretrainedModel): config.set_if_not_exist('num_buckets', None) config.set_if_not_exist('max_distance', None) config.set_if_not_exist('relative_attention', False) - config.set_if_not_exist('residual_scaling', 1.0) def forward(self, decoder_input_ids: Tensor, diff --git a/tensorrt_llm/models/gemma/config.py b/tensorrt_llm/models/gemma/config.py index a9c7e05d72..8e176c4ed7 100644 --- a/tensorrt_llm/models/gemma/config.py +++ b/tensorrt_llm/models/gemma/config.py @@ -52,7 +52,7 @@ class GemmaConfig(PretrainedConfig): final_logit_softcapping: Optional[float] = None, attn_logit_softcapping: Optional[float] = None, mapping: Optional[Union[Mapping, dict]] = None, - sliding_window_pattern: int = None, + _sliding_window_pattern: int = None, rope_local_base_freq: int = None, sliding_window: int = None, **kwargs, @@ -94,7 +94,7 @@ class GemmaConfig(PretrainedConfig): if self.is_gemma_2: self.attn_logit_softcapping = attn_logit_softcapping if self.is_gemma_3: - self.sliding_window_pattern = sliding_window_pattern + self._sliding_window_pattern = _sliding_window_pattern self.rope_local_base_freq = rope_local_base_freq self.sliding_window = sliding_window diff --git a/tensorrt_llm/models/gemma/model.py b/tensorrt_llm/models/gemma/model.py index a80b9e771d..2091f111f8 100644 --- a/tensorrt_llm/models/gemma/model.py +++ b/tensorrt_llm/models/gemma/model.py @@ -28,7 +28,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, LayerNormType, from ...layers import (Attention, AttentionMaskType, AttentionParams, ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams, LoraParams, PositionEmbeddingType, RmsNorm) -from ...lora_manager import LoraConfig, use_lora +from ...lora_helper import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, @@ -74,7 +74,7 @@ class GemmaDecoderLayer(Module): gemma3_config.query_pre_attn_scalar) / math.sqrt( config.head_size) is_sliding = bool( - (layer_idx + 1) % gemma3_config.sliding_window_pattern) + (layer_idx + 1) % gemma3_config._sliding_window_pattern) rotary_base_local = config.rope_local_base_freq self.attention = Attention( diff --git a/tensorrt_llm/models/gpt/model.py b/tensorrt_llm/models/gpt/model.py index aecfcda64c..89267a90f7 100644 --- a/tensorrt_llm/models/gpt/model.py +++ b/tensorrt_llm/models/gpt/model.py @@ -21,7 +21,7 @@ from ...functional import (Tensor, is_gated_activation, non_gated_version, recv, from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear, Embedding, GatedMLP, LayerNorm, MoeConfig, PositionEmbeddingType) -from ...lora_manager import LoraConfig, use_lora +from ...lora_helper import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module from ...quantization import QuantMode diff --git a/tensorrt_llm/models/grok/model.py b/tensorrt_llm/models/grok/model.py index 8fc34349f9..9ff22cd71c 100644 --- a/tensorrt_llm/models/grok/model.py +++ b/tensorrt_llm/models/grok/model.py @@ -18,7 +18,7 @@ from ..._utils import pad_vocab_size from ...functional import Tensor, recv, send from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear, Embedding, MoeConfig, PositionEmbeddingType, RmsNorm) -from ...lora_manager import LoraConfig, use_lora +from ...lora_helper import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py index 259f3e2f9a..2e272772ad 100644 --- a/tensorrt_llm/models/llama/model.py +++ b/tensorrt_llm/models/llama/model.py @@ -25,7 +25,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor, from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear, Embedding, FusedGatedMLP, GatedMLP, PositionEmbeddingType, RmsNorm) -from ...lora_manager import LoraConfig, use_lora +from ...lora_helper import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module from ...quantization.functional import fused_layernorm diff --git a/tensorrt_llm/models/mllama/model.py b/tensorrt_llm/models/mllama/model.py index 5f9c622fa8..95a261350b 100644 --- a/tensorrt_llm/models/mllama/model.py +++ b/tensorrt_llm/models/mllama/model.py @@ -32,9 +32,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams, ColumnLinear, Embedding, FusedGatedMLP, GatedMLP, GroupNorm, KeyValueCacheParams, LayerNorm, LoraParams, RmsNorm) -from tensorrt_llm.lora_manager import (LoraConfig, - get_default_trtllm_modules_to_hf_modules, - use_lora) +from tensorrt_llm.lora_helper import (LoraConfig, + get_default_trtllm_modules_to_hf_modules, + use_lora) from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 7b2af7af15..dcc375320e 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -67,7 +67,7 @@ class Gemma2ConfigGroup: class Gemma3ConfigGroup: query_pre_attn_scalar: float final_logit_softcapping: Optional[float] - sliding_window_pattern: int + _sliding_window_pattern: int rope_local_base_freq: int sliding_window: int @@ -139,6 +139,7 @@ class QuantConfig: has_zero_point (bool): Whether to use zero point for quantization. Defaults to False. pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False. exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None. + mamba_ssm_cache_dtype (str, optional): The data type for mamba SSM cache. Defaults to None. """ quant_algo: Optional[QuantAlgo] = None kv_cache_quant_algo: Optional[QuantAlgo] = None @@ -149,6 +150,7 @@ class QuantConfig: has_zero_point: bool = False pre_quant_scale: bool = False exclude_modules: Optional[List[str]] = None + mamba_ssm_cache_dtype: Optional[str] = None @cached_property def quant_mode(self) -> QuantModeWrapper: @@ -1817,7 +1819,7 @@ def preprocess_perlayer_weights(weights, weights[new_name] = weights[name] weights[ new_name + - "_interleaved"] = torch.ops.trtllm.nvfp4_block_scale_interleave( + "_interleaved"] = torch.ops.trtllm.block_scale_interleave( weights[name].view(fp4_utils.float4_sf_dtype).cpu( ).contiguous()).reshape(nrows, ncols).view( fp4_utils.float4_sf_dtype) diff --git a/tensorrt_llm/models/phi/model.py b/tensorrt_llm/models/phi/model.py index 6e6dd3579b..9c90e114e9 100644 --- a/tensorrt_llm/models/phi/model.py +++ b/tensorrt_llm/models/phi/model.py @@ -20,7 +20,7 @@ from ..._utils import pad_vocab_size from ...functional import Tensor from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear, Embedding, LayerNorm) -from ...lora_manager import LoraConfig, use_lora +from ...lora_helper import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, diff --git a/tensorrt_llm/models/phi3/model.py b/tensorrt_llm/models/phi3/model.py index 5f058f147e..5bdc24f8ed 100644 --- a/tensorrt_llm/models/phi3/model.py +++ b/tensorrt_llm/models/phi3/model.py @@ -8,7 +8,7 @@ from ...functional import PositionEmbeddingType, Tensor from ...layers import (MLP, MOE, Attention, AttentionMaskType, BlockSparseAttnParams, ColumnLinear, Embedding, LayerNorm, MoeConfig, RmsNorm) -from ...lora_manager import LoraConfig, use_lora +from ...lora_helper import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, diff --git a/tensorrt_llm/models/qwen/model.py b/tensorrt_llm/models/qwen/model.py index 0eb6e8ac44..f32a4036d8 100644 --- a/tensorrt_llm/models/qwen/model.py +++ b/tensorrt_llm/models/qwen/model.py @@ -26,8 +26,8 @@ from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear, Embedding, GatedMLP, RmsNorm, SharedMoE) from ...layers.moe import MOEWeightWrapper from ...logger import logger -from ...lora_manager import (LoraConfig, - get_default_trtllm_modules_to_hf_modules, use_lora) +from ...lora_helper import (LoraConfig, + get_default_trtllm_modules_to_hf_modules, use_lora) from ...mapping import Mapping from ...module import Module from ...quantization import QuantAlgo diff --git a/tensorrt_llm/quantization/functional.py b/tensorrt_llm/quantization/functional.py index 84dc1b74a5..30986412c8 100644 --- a/tensorrt_llm/quantization/functional.py +++ b/tensorrt_llm/quantization/functional.py @@ -950,10 +950,12 @@ def symmetric_quantize_last_axis_of_batched_matrix(weight, quant_mode): return qweight, scale -def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor, - quant_mode: torch.dtype, - act_dtype: torch.dtype, - sm_: int = -1) -> torch.Tensor: +def preprocess_weights_for_mixed_gemm( + tensor: torch.Tensor, + quant_mode: torch.dtype, + act_dtype: torch.dtype, + sm_: int = -1, + do_weight_interleave: bool = True) -> torch.Tensor: sm_ = sm_ if sm_ > 0 else get_sm_version() if len(tensor.shape) == 2: tensor = tensor.unsqueeze(0) @@ -988,13 +990,12 @@ def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor, assert (num_rows % B_ROWS_PER_MMA == 0) assert (num_cols % MMA_SHAPE_N == 0) - row_idx_list = [ - (row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA + - permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][row_idx % - B_ROWS_PER_MMA] - for row_idx in range(num_rows) - ] - tensor = tensor[:, row_idx_list, :] + if do_weight_interleave: + row_idx_list = [(row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA + + permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][ + row_idx % B_ROWS_PER_MMA] + for row_idx in range(num_rows)] + tensor = tensor[:, row_idx_list, :] # subbyte_transpose original_shape = tensor.shape @@ -1010,42 +1011,63 @@ def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor, else: tensor = tensor.permute(0, 2, 1).reshape(original_shape) - # interleave_column_major_tensor - interleave = BITS_PER_ELT_A // BITS_PER_ELT_B - if interleave > 1 and sm_ < 90: - rows_per_tile = 128 * 8 // BITS_PER_ELT_A - elts_in_int32 = 32 // BITS_PER_ELT_B + if do_weight_interleave: + # interleave_column_major_tensor + interleave = BITS_PER_ELT_A // BITS_PER_ELT_B + if interleave > 1 and sm_ < 90: + rows_per_tile = 128 * 8 // BITS_PER_ELT_A + elts_in_int32 = 32 // BITS_PER_ELT_B - assert (num_rows % elts_in_int32 == 0) - assert (num_rows % rows_per_tile == 0) + assert (num_rows % elts_in_int32 == 0) + assert (num_rows % rows_per_tile == 0) - tensor = tensor.reshape(num_experts, -1, interleave, - num_rows // rows_per_tile, - rows_per_tile * 4 // elts_in_int32) - tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape) + tensor = tensor.reshape(num_experts, -1, interleave, + num_rows // rows_per_tile, + rows_per_tile * 4 // elts_in_int32) + tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape) - # add_bias_and_interleave_quantized_tensor_inplace - if BITS_PER_ELT_B == 8: - tensor += -256 * (tensor > 127).byte() + 128 - tensor = tensor.reshape(-1, 4)[:, [0, 2, 1, 3]].reshape(tensor.shape) - elif BITS_PER_ELT_B == 4: - tensor = tensor.view(torch.uint8) - high_tensor = (tensor >> 4).unsqueeze(-1) - low_tensor = ((tensor << 4) >> 4).unsqueeze(-1) - new_tensor = torch.cat([low_tensor, high_tensor], - dim=-1).reshape(tensor.shape[0], tensor.shape[1], - -1) - new_tensor = new_tensor.reshape( - -1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape) - new_tensor += -16 * (new_tensor > 7).byte() + 8 - new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16 - tensor = new_tensor.view(torch.int8) - else: - raise NotImplementedError + # add_bias_and_interleave_quantized_tensor_inplace + if BITS_PER_ELT_B == 8: + tensor += -256 * (tensor > 127).byte() + 128 + tensor = tensor.reshape(-1, 4)[:, + [0, 2, 1, 3]].reshape(tensor.shape) + elif BITS_PER_ELT_B == 4: + tensor = tensor.view(torch.uint8) + high_tensor = (tensor >> 4).unsqueeze(-1) + low_tensor = ((tensor << 4) >> 4).unsqueeze(-1) + new_tensor = torch.cat([low_tensor, high_tensor], + dim=-1).reshape(tensor.shape[0], + tensor.shape[1], -1) + new_tensor = new_tensor.reshape( + -1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape) + new_tensor += -16 * (new_tensor > 7).byte() + 8 + new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16 + tensor = new_tensor.view(torch.int8) + else: + raise NotImplementedError return tensor.squeeze(0).contiguous() +def get_weight_scale_interleave_factor(interleaved_dim: int, + group_size: int = 128) -> int: + # Calculate the weight_scale interleave factor for W4A8 groupwise MoE quant + # only Hopper w4a8 does interleave for weight scale, other arch or Hopper w4a16 default to 1 + factor = 1 + if get_sm_version() == 90: + if interleaved_dim % (4 * group_size) == 0: + factor = 4 + elif interleaved_dim % (2 * group_size) == 0: + factor = 2 + elif interleaved_dim % group_size == 0: + factor = 1 + else: + raise NotImplementedError( + f"Interleaved dimension must be a multiple of group_size ({group_size}), received {interleaved_dim}." + ) + return factor + + def validate_group_size(layer): # TODO: Remove this function and its usage after W4A8-AWQ with group_size = 64 is implemented. W4A8_AWQ = 8 diff --git a/tensorrt_llm/quantization/layers.py b/tensorrt_llm/quantization/layers.py index a1c3982b72..7aa8e80800 100644 --- a/tensorrt_llm/quantization/layers.py +++ b/tensorrt_llm/quantization/layers.py @@ -2218,7 +2218,7 @@ class FP4Linear(Linear): qkv_block_scale, tllm_key.replace( 'weight', "weights_block_scaling_factor_interleaved"): - torch.ops.trtllm.nvfp4_block_scale_interleave( + torch.ops.trtllm.block_scale_interleave( qkv_block_scale.view( torch.uint8).cpu().contiguous()).reshape( qkv_block_scale.shape).view( @@ -2238,7 +2238,7 @@ class FP4Linear(Linear): elif tllm_key.endswith("weights_block_scaling_factor"): return weights elif tllm_key.endswith("weights_block_scaling_factor_interleaved"): - return torch.ops.trtllm.nvfp4_block_scale_interleave( + return torch.ops.trtllm.block_scale_interleave( weights.view(torch.uint8).cpu().contiguous()).reshape( weights.shape).view(torch.float8_e4m3fn) elif tllm_key.endswith("weights_global_scaling_factor"): @@ -2379,7 +2379,7 @@ class FP4RowLinear(RowLinear): elif tllm_key.endswith("weights_block_scaling_factor"): return weights elif tllm_key.endswith("weights_block_scaling_factor_interleaved"): - return torch.ops.trtllm.nvfp4_block_scale_interleave( + return torch.ops.trtllm.block_scale_interleave( weights.view(torch.uint8).cpu().contiguous()).reshape( weights.shape).view(torch.float8_e4m3fn) elif tllm_key.endswith("weights_global_scaling_factor"): diff --git a/tensorrt_llm/quantization/mode.py b/tensorrt_llm/quantization/mode.py index b555625f7b..a8b38d885f 100644 --- a/tensorrt_llm/quantization/mode.py +++ b/tensorrt_llm/quantization/mode.py @@ -41,6 +41,8 @@ class QuantAlgo(StrEnum, metaclass=BaseEnumMeta): MIXED_PRECISION = auto() NVFP4 = auto() W4A8_MXFP4_FP8 = auto() + W4A8_MXFP4_MXFP8 = auto() + W4A16_MXFP4 = auto() NO_QUANT = auto() @@ -90,6 +92,8 @@ class QuantMode(IntFlag): NVFP4_KV_CACHE = auto() # W4A8 MXFP4 W4A8_MXFP4_FP8 = auto() + W4A8_MXFP4_MXFP8 = auto() + W4A16_MXFP4 = auto() # The smallest power-of-two that is not used by a flag. Do not call auto() after that line. COUNT = auto() @@ -178,6 +182,16 @@ class QuantMode(IntFlag): def has_w4a8_mxfp4_fp8(self): return self._any(self.W4A8_MXFP4_FP8) + def has_w4a8_mxfp4_mxfp8(self): + return self._any(self.W4A8_MXFP4_MXFP8) + + def has_w4a16_mxfp4(self): + return self._any(self.W4A16_MXFP4) + + def has_mxfp4(self): + return self._any(self.W4A8_MXFP4_FP8 | self.W4A8_MXFP4_MXFP8 + | self.W4A16_MXFP4) + def has_weight_quant(self): return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS) @@ -189,7 +203,9 @@ class QuantMode(IntFlag): | self.W4A8_QSERVE | self.FP8_1x128_128x128 | self.NVFP4 - | self.W4A8_MXFP4_FP8) + | self.W4A8_MXFP4_FP8 + | self.W4A16_MXFP4 + | self.W4A8_MXFP4_MXFP8) if exclude_kv_cache: return has_quant @@ -225,7 +241,9 @@ class QuantMode(IntFlag): use_fp8_rowwise=False, use_nvfp4=False, use_w4a8_qserve=False, - use_w4a8_mxfp4_fp8=False): + use_w4a8_mxfp4_fp8=False, + use_w4a8_mxfp4_mxfp8=False, + use_w4a16_mxfp4=False): def raise_error(): raise ValueError(f"Unsupported combination of QuantMode args: " @@ -242,7 +260,9 @@ class QuantMode(IntFlag): f"{use_fp8_rowwise=}, " f"{use_nvfp4=}, " f"{use_w4a8_qserve=}, " - f"{use_w4a8_mxfp4_fp8=}") + f"{use_w4a8_mxfp4_fp8=}, " + f"{use_w4a8_mxfp4_mxfp8=}, " + f"{use_w4a16_mxfp4=}") # We must quantize weights when we quantize activations. if quantize_activations and not quantize_weights: @@ -300,6 +320,12 @@ class QuantMode(IntFlag): if use_w4a8_mxfp4_fp8: mode = mode | QuantMode.W4A8_MXFP4_FP8 + if use_w4a8_mxfp4_mxfp8: + mode = mode | QuantMode.W4A8_MXFP4_MXFP8 + + if use_w4a16_mxfp4: + mode = mode | QuantMode.W4A16_MXFP4 + return mode @staticmethod @@ -375,6 +401,10 @@ class QuantMode(IntFlag): quant_mode = QuantMode.from_description(use_nvfp4=True) elif quant_algo == QuantAlgo.W4A8_MXFP4_FP8: quant_mode = QuantMode.from_description(use_w4a8_mxfp4_fp8=True) + elif quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: + quant_mode = QuantMode.from_description(use_w4a8_mxfp4_mxfp8=True) + elif quant_algo == QuantAlgo.W4A16_MXFP4: + quant_mode = QuantMode.from_description(use_w4a16_mxfp4=True) else: quant_mode = QuantMode(0) @@ -409,6 +439,10 @@ class QuantMode(IntFlag): self.has_nvfp4(), 'enable_w4a8_mxfp4_fp8': self.has_w4a8_mxfp4_fp8(), + 'enable_w4a8_mxfp4_mxfp8': + self.has_w4a8_mxfp4_mxfp8(), + 'enable_w4a16_mxfp4': + self.has_w4a16_mxfp4(), 'fp8_kv_cache': self.has_fp8_kv_cache(), 'use_weight_only': diff --git a/tensorrt_llm/quantization/utils/fp4_utils.py b/tensorrt_llm/quantization/utils/fp4_utils.py index 43bf4e8c98..c62a4e2527 100644 --- a/tensorrt_llm/quantization/utils/fp4_utils.py +++ b/tensorrt_llm/quantization/utils/fp4_utils.py @@ -49,8 +49,7 @@ def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> torch.Tensor: to [r0, rN/2, r1, rN/2+1, ..., r(N/2-1), r(N-1)] """ - assert x.dim() == 2, f"x should be a 2D tensor, not {x.dim()}" - M, K = x.shape + M = x.shape[0] assert M % 2 == 0, f"x.shape[0] must be even, not {M}" row_indices = torch.arange(M, dtype=torch.long) @@ -120,11 +119,8 @@ def get_shuffle_matrix_a_row_indices(input_tensor: torch.Tensor, - We do NOT try to handle custom e2m1 memory usage (i.e. no 'K/2' bytes). - Instead, we purely reorder rows in a standard PyTorch shape [M, K]. """ - assert input_tensor.dim( - ) == 2, f"input_tensor should be a 2D tensor, not {input_tensor.dim()}" - - # M, K from the input - M, K = input_tensor.shape + # M from the input + M = input_tensor.shape[0] # Choose block size 16 or 32 shuffle_block_size = get_shuffle_block_size(epilogue_tile_m) @@ -168,7 +164,7 @@ def get_shuffle_matrix_sf_a_row_indices( num_elts_per_sf: int = 16) -> torch.Tensor: assert input_tensor.dtype == float4_sf_dtype - assert num_elts_per_sf == 16 + assert num_elts_per_sf == 16 or num_elts_per_sf == 32 assert input_tensor.dim( ) == 2, f"input_tensor should be a 2D tensor, not {input_tensor.dim()}" @@ -207,4 +203,4 @@ def shuffle_matrix_sf_a( input_tensor, row_indices.to(input_tensor.device)) # 128x4 - return torch.ops.trtllm.nvfp4_block_scale_interleave(w_shuffled) + return torch.ops.trtllm.block_scale_interleave(w_shuffled) diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 19bd24671d..4c486a1511 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -302,6 +302,8 @@ def _silu_and_mul_post_quant_kernel( def silu_and_mul_masked_post_quant_fwd( + output: torch.Tensor, + output_scale: torch.Tensor, input: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor, @@ -328,18 +330,6 @@ def silu_and_mul_masked_post_quant_fwd( g, m, k = input.shape k = k // 2 - # Create output - output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda") - - # Create output scale - alignment = 4 - scale_k = ceil_div(k, quant_group_size) - m_padded = align(m, alignment) - scale_k_padded = align(scale_k, alignment) - output_scale = torch.zeros((g, scale_k_padded // 4, m_padded), - dtype=torch.int32, - device='cuda') - # Get block/grid/stage/warp expert_num = len(masked_m) @@ -382,7 +372,7 @@ def silu_and_mul_masked_post_quant_fwd( g, tma_stride_check=True, ) - return output, output_scale + return output_scale @triton.jit diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index ea5ce01f62..361c0fc0c0 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -346,17 +346,20 @@ class SamplingParams: if self.pad_id is None: self.pad_id = self.end_id + def _encode(tokenizer, text, add_special_tokens): + try: + return tokenizer.encode(text, add_special_tokens=add_special_tokens) + except TypeError: + # For tiktokenizer, the encode method does not have add_special_tokens argument + return tokenizer.encode(text) + if self.bad is not None: strs = [self.bad] if isinstance(self.bad, str) else self.bad - self._bad_word_ids = [ - tokenizer.encode(s, add_special_tokens=add_special_tokens) for s in strs - ] + self._bad_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs] if self.stop is not None: strs = [self.stop] if isinstance(self.stop, str) else self.stop - self._stop_word_ids = [ - tokenizer.encode(s, add_special_tokens=add_special_tokens) for s in strs - ] + self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs] return self diff --git a/tensorrt_llm/serialization.py b/tensorrt_llm/serialization.py index b295e9f0b0..e5a122f892 100644 --- a/tensorrt_llm/serialization.py +++ b/tensorrt_llm/serialization.py @@ -13,6 +13,105 @@ BASE_EXAMPLE_CLASSES = { "AssertionError", "RuntimeError" ], # each Exception Error class needs to be added explicitly "collections": ["OrderedDict"], + "datetime": ["timedelta"], + "pathlib": ["PosixPath"], + "llmapi.run_llm_with_postproc": ["perform_faked_oai_postprocess" + ], # only used in tests + ### starting import of torch models classes. They are used in test_llm_multi_gpu.py. + "tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"], + "tensorrt_llm._torch.models.modeling_bert": + ["BertForSequenceClassification"], + "tensorrt_llm._torch.models.modeling_clip": ["CLIPVisionModel"], + "tensorrt_llm._torch.models.modeling_deepseekv3": ["DeepseekV3ForCausalLM"], + "tensorrt_llm._torch.models.modeling_gemma3": ["Gemma3ForCausalLM"], + "tensorrt_llm._torch.models.modeling_hyperclovax": ["HCXVisionForCausalLM"], + "tensorrt_llm._torch.models.modeling_llama": [ + "Eagle3LlamaForCausalLM", + "LlamaForCausalLM", + "Llama4ForCausalLM", + "Llama4ForConditionalGeneration", + ], + "tensorrt_llm._torch.models.modeling_llava_next": ["LlavaNextModel"], + "tensorrt_llm._torch.models.modeling_mistral": ["MistralForCausalLM"], + "tensorrt_llm._torch.models.modeling_mixtral": ["MixtralForCausalLM"], + "tensorrt_llm._torch.models.modeling_mllama": + ["MllamaForConditionalGeneration"], + "tensorrt_llm._torch.models.modeling_nemotron": ["NemotronForCausalLM"], + "tensorrt_llm._torch.models.modeling_nemotron_h": ["NemotronHForCausalLM"], + "tensorrt_llm._torch.models.modeling_nemotron_nas": + ["NemotronNASForCausalLM"], + "tensorrt_llm._torch.models.modeling_qwen": + ["Qwen2ForCausalLM", "Qwen2ForProcessRewardModel", "Qwen2ForRewardModel"], + "tensorrt_llm._torch.models.modeling_qwen2vl": + ["Qwen2VLModel", "Qwen2_5_VLModel"], + "tensorrt_llm._torch.models.modeling_qwen3": ["Qwen3ForCausalLM"], + "tensorrt_llm._torch.models.modeling_qwen3_moe": ["Qwen3MoeForCausalLM"], + "tensorrt_llm._torch.models.modeling_qwen_moe": ["Qwen2MoeForCausalLM"], + "tensorrt_llm._torch.models.modeling_siglip": ["SiglipVisionModel"], + "tensorrt_llm._torch.models.modeling_vila": ["VilaModel"], + "tensorrt_llm._torch.models.modeling_gpt_oss": ["GptOssForCausalLM"], + ### ending import of torch models classes + "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"], + "tensorrt_llm._torch.pyexecutor.llm_request": + ["LogitsStorage", "PyResult", "LlmResult", "LlmResponse", "LogProbStorage"], + "tensorrt_llm._torch.speculative.mtp": ["MTPConfig"], + "tensorrt_llm._torch.speculative.interface": ["SpeculativeDecodingMode"], + "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"], + "tensorrt_llm.auto_parallel.config": ["AutoParallelConfig", "CostModel"], + "tensorrt_llm.auto_parallel.cluster_info": + ["ClusterInfo", "MathThroughput"], + "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"], + "tensorrt_llm.bindings.executor": [ + "BatchingType", "CacheTransceiverConfig", "CapacitySchedulerPolicy", + "ContextPhaseParams", "ContextChunkingPolicy", "DynamicBatchConfig", + "ExecutorConfig", "ExtendedRuntimePerfKnobConfig", "Response", "Result", + "FinishReason", "KvCacheConfig", "KvCacheTransferMode", + "KvCacheRetentionConfig", + "KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig", + "SchedulerConfig" + ], + "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig"], + "tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"], + "tensorrt_llm.builder": ["BuildConfig"], + "tensorrt_llm.disaggregated_params": ["DisaggregatedParams"], + "tensorrt_llm.inputs.multimodal": ["MultimodalInput"], + "tensorrt_llm.executor.postproc_worker": [ + "PostprocArgs", "PostprocParams", "PostprocWorkerConfig", + "PostprocWorker.Input", "PostprocWorker.Output" + ], + "tensorrt_llm.executor.request": [ + "CancellingRequest", "GenerationRequest", "LoRARequest", + "PromptAdapterRequest" + ], + "tensorrt_llm.executor.result": [ + "CompletionOutput", "DetokenizedGenerationResultBase", + "GenerationResult", "GenerationResultBase", "IterationResult", + "Logprob", "LogProbsResult", "ResponseWrapper" + ], + "tensorrt_llm.executor.utils": ["ErrorResponse", "WorkerCommIpcAddrs"], + "tensorrt_llm.executor.worker": ["GenerationExecutorWorker", "worker_main"], + "tensorrt_llm.llmapi.llm_args": [ + "_ModelFormatKind", "_ParallelConfig", "CalibConfig", + "CapacitySchedulerPolicy", "KvCacheConfig", "LookaheadDecodingConfig", + "TrtLlmArgs", "SchedulerConfig", "LoadFormat", "DynamicBatchConfig" + ], + "tensorrt_llm.llmapi.mpi_session": ["RemoteTask"], + "tensorrt_llm.llmapi.llm_utils": + ["CachedModelLoader._node_build_task", "LlmBuildStats"], + "tensorrt_llm.llmapi.tokenizer": ["TransformersTokenizer"], + "tensorrt_llm.lora_manager": ["LoraConfig"], + "tensorrt_llm.mapping": ["Mapping"], + "tensorrt_llm.models.modeling_utils": + ["QuantConfig", "SpeculativeDecodingMode"], + "tensorrt_llm.plugin.plugin": ["PluginConfig"], + "tensorrt_llm.sampling_params": + ["SamplingParams", "GuidedDecodingParams", "GreedyDecodingParams"], + "tensorrt_llm.serve.postprocess_handlers": [ + "chat_response_post_processor", "chat_stream_post_processor", + "completion_stream_post_processor", + "completion_response_post_processor", "CompletionPostprocArgs", + "ChatPostprocArgs" + ], "torch._utils": ["_rebuild_tensor_v2"], "torch.storage": ["_load_from_bytes"], } diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index d90578ce36..1b1e15ec62 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio import os +import re import signal import traceback from contextlib import asynccontextmanager @@ -13,6 +14,7 @@ import uvicorn from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.routing import Mount from transformers import AutoConfig, AutoProcessor from tensorrt_llm._tensorrt_engine import LLM @@ -25,6 +27,7 @@ from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.logger import logger +from tensorrt_llm.metrics.collector import MetricsCollector from tensorrt_llm.serve.chat_utils import (check_multiple_response, parse_chat_messages_coroutines) from tensorrt_llm.serve.metadata_server import create_metadata_server @@ -42,7 +45,7 @@ from tensorrt_llm.serve.postprocess_handlers import ( completion_stream_post_processor) from tensorrt_llm.version import __version__ as VERSION -from .._utils import nvtx_mark +from .._utils import nvtx_mark, set_prometheus_multiproc_dir # yapf: enale TIMEOUT_KEEP_ALIVE = 5 # seconds. @@ -78,6 +81,13 @@ class OpenAIServer: self.model = model_dir.name else: self.model = model + self.metrics_collector = None + if self.llm.args.return_perf_metrics: + set_prometheus_multiproc_dir() + self.metrics_collector = MetricsCollector({ + "model_name": "undefined", + "engine_type": "undefined" + }) @asynccontextmanager async def lifespan(app: FastAPI): @@ -151,6 +161,32 @@ class OpenAIServer: self.app.add_api_route("/v1/chat/completions", self.openai_chat, methods=["POST"]) + if self.llm.args.return_perf_metrics: + # register /prometheus/metrics + self.mount_metrics() + + def mount_metrics(self): + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + from prometheus_fastapi_instrumentator import Instrumentator + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + Instrumentator( + should_group_status_codes=False, + should_respect_env_var=True, + excluded_handlers=[ + ".*" + ], + registry=registry, + ).add().instrument(self.app).expose(self.app) + metrics_app = make_asgi_app(registry=registry) + metrics_route = Mount("/prometheus/metrics", metrics_app) + metrics_route.path_regex = re.compile("^/prometheus/metrics(?P<path>.*)$") + self.app.routes.append(metrics_route) async def health(self) -> Response: return Response(status_code=200) @@ -228,6 +264,8 @@ class OpenAIServer: post_processor, args = postproc_params.post_processor, postproc_params.postproc_args async for res in promise: pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) + if res.finished and self.metrics_collector: + self.metrics_collector.log_metrics_dict(res.metrics_dict) for pp_res in pp_results: yield pp_res yield "data: [DONE]\n\n" @@ -245,6 +283,8 @@ class OpenAIServer: # Add prompt_tokens_ids to the response if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": chat_response.prompt_token_ids = promise.prompt_token_ids + if promise.finished and self.metrics_collector: + self.metrics_collector.log_metrics_dict(promise.metrics_dict) return chat_response try: @@ -337,6 +377,8 @@ class OpenAIServer: if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": # Include prompt token ids for context-only requests pp_result.prompt_token_ids = response.prompt_token_ids + if response.finished and self.metrics_collector: + self.metrics_collector.log_metrics_dict(response.metrics_dict) return pp_result def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse: @@ -372,6 +414,8 @@ class OpenAIServer: pp_result = post_processor(output, args) else: pp_result = output.outputs[0]._postprocess_result + if output.finished and self.metrics_collector: + self.metrics_collector.log_metrics_dict(output.metrics_dict) for pp_res in pp_result: yield pp_res diff --git a/tensorrt_llm/serve/scripts/backend_request_func.py b/tensorrt_llm/serve/scripts/backend_request_func.py index c65cd8e839..8959fc6406 100644 --- a/tensorrt_llm/serve/scripts/backend_request_func.py +++ b/tensorrt_llm/serve/scripts/backend_request_func.py @@ -30,6 +30,7 @@ class RequestFuncInput: extra_body: Optional[dict] = None ignore_eos: bool = False language: Optional[str] = None + multi_modal_content: Optional[dict] = None @dataclass @@ -54,7 +55,10 @@ async def async_request_trt_llm( session: Optional[aiohttp.ClientSession] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith("generate_stream") + if not api_url.endswith("generate_stream"): + raise ValueError( + f"TRT-LLM API URL must end with 'generate_stream', but got: {api_url}" + ) request_session = aiohttp.ClientSession( trust_env=True, @@ -144,9 +148,10 @@ async def async_request_openai_completions( session: Optional[aiohttp.ClientSession] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith( - ("completions", "profile") - ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + if not api_url.endswith(("completions", "profile")): + raise ValueError( + "OpenAI Completions API URL must end with 'completions' or 'profile'." + ) request_session = aiohttp.ClientSession( trust_env=True, @@ -268,9 +273,9 @@ async def async_request_openai_chat_completions( session: Optional[aiohttp.ClientSession] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith( - ("chat/completions", "profile" - )), "OpenAI Chat Completions API URL must end with 'chat/completions'." + if not api_url.endswith(("chat/completions", "profile")): + raise ValueError( + "OpenAI Chat Completions API URL must end with 'chat/completions'.") request_session = aiohttp.ClientSession( trust_env=True, @@ -292,16 +297,12 @@ async def async_request_openai_chat_completions( [isinstance(i, int) for i in request_func_input.prompt]): payload["prompt_token_ids"] = request_func_input.prompt else: - assert isinstance(request_func_input.prompt, - str), "Prompt must be a string or a list of integers" - payload["messages"].append({ - "role": - "user", - "content": [{ - "type": "text", - "text": request_func_input.prompt - }] - }) + if not isinstance(request_func_input.prompt, str): + raise ValueError("Prompt must be a string or a list of integers") + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.extend(request_func_input.multi_modal_content) + payload["messages"].append({"role": "user", "content": content}) if streaming: payload["stream_options"] = {"include_usage": True} diff --git a/tensorrt_llm/serve/scripts/benchmark_dataset.py b/tensorrt_llm/serve/scripts/benchmark_dataset.py index 35d2744aea..02000cddba 100644 --- a/tensorrt_llm/serve/scripts/benchmark_dataset.py +++ b/tensorrt_llm/serve/scripts/benchmark_dataset.py @@ -17,18 +17,24 @@ TODO: Implement CustomDataset to parse a JSON file and convert its contents into SampleRequest instances, similar to the approach used in ShareGPT. """ +import base64 +import io import json import logging import random +import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Mapping, Optional, Union import numpy as np import pandas as pd +import torch from datasets import load_dataset +from PIL import Image from transformers import PreTrainedTokenizerBase +from tensorrt_llm.inputs.utils import convert_image_mode from tensorrt_llm.serve.scripts.benchmark_utils import download_and_cache_file logger = logging.getLogger(__name__) @@ -38,6 +44,162 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- +def timing_decorator(method_name: str): + """ + Decorator to time method execution and print the results. + + Args: + method_name: Name to display in timing output (e.g., 'load_data', 'sample') + """ + + def decorator(func): + + def wrapper(self, *args, **kwargs): + dataset_name = self.__class__.__name__ + start_time = time.perf_counter() + print(f"{dataset_name}.{method_name}() started...") + + try: + result = func(self, *args, **kwargs) + end_time = time.perf_counter() + duration = end_time - start_time + print( + f"{dataset_name}.{method_name}() completed in {duration:.4f} seconds" + ) + return result + except Exception as e: + end_time = time.perf_counter() + duration = end_time - start_time + print( + f"{dataset_name}.{method_name}() failed after {duration:.4f} seconds: {str(e)}" + ) + raise + + return wrapper + + return decorator + + +def auto_time_methods(*method_names): + """ + Class decorator that automatically applies timing to specified methods + in the class and all its subclasses. + + Usage: + @auto_time_methods("load_data", "sample") + class MyDataset(BenchmarkDataset): + def load_data(self): # Will be automatically timed + pass + def sample(self): # Will be automatically timed + pass + """ + + def class_decorator(cls): + # Store the method names that should be timed + cls._timed_methods = method_names + + # Override __init_subclass__ to automatically apply timing to subclasses + original_init_subclass = getattr(cls, '__init_subclass__', + lambda **kwargs: None) + + @classmethod + def __init_subclass__(subcls, **kwargs): + original_init_subclass(**kwargs) + + # Apply timing to the specified methods if they exist in the subclass + for method_name in method_names: + if hasattr(subcls, method_name): + original_method = getattr(subcls, method_name) + + # Only wrap if not already wrapped (check for our wrapper's signature) + if not hasattr(original_method, '_is_timed'): + timed_method = timing_decorator(method_name)( + original_method) + timed_method._is_timed = True + setattr(subcls, method_name, timed_method) + + cls.__init_subclass__ = __init_subclass__ + + # Also apply timing to methods in the current class + for method_name in method_names: + if hasattr(cls, method_name): + original_method = getattr(cls, method_name) + if not hasattr(original_method, '_is_timed'): + timed_method = timing_decorator(method_name)( + original_method) + timed_method._is_timed = True + setattr(cls, method_name, timed_method) + + return cls + + return class_decorator + + +def batch_tokenize_prompts( + prompts: list[str], + tokenizer: PreTrainedTokenizerBase, + batch_size: int = 1000, + progress_name: str = "prompts") -> tuple[list[int], list[list[int]]]: + """ + Efficiently tokenize a list of prompts using batch processing. + + Args: + prompts: List of text prompts to tokenize + tokenizer: The tokenizer to use + batch_size: Number of prompts to process in each batch + progress_name: Name to show in progress messages + + Returns: + Tuple of (prompt_lengths, prompt_token_ids) where: + - prompt_lengths: List of prompt lengths (number of tokens per prompt) + - prompt_token_ids: List of token ID lists for each prompt + """ + import time + + if not prompts: + return [], [] + + print( + f"Batch tokenizing {len(prompts)} {progress_name} (batch_size={batch_size})..." + ) + + prompt_lengths = [] + prompt_token_ids = [] + total_time = 0 + + for i in range(0, len(prompts), batch_size): + batch_prompts = prompts[i:i + batch_size] + + # Batch tokenization + start_time = time.perf_counter() + batch_encoded = tokenizer(batch_prompts, + padding=False, + truncation=False) + batch_time = time.perf_counter() - start_time + total_time += batch_time + + # Extract lengths and token IDs + for j in range(len(batch_prompts)): + token_ids = batch_encoded.input_ids[j] + prompt_lengths.append(len(token_ids)) + prompt_token_ids.append(token_ids) + + # Progress reporting + if (i + batch_size) % 5000 == 0 or (i + batch_size) >= len(prompts): + processed = min(i + batch_size, len(prompts)) + avg_time = total_time / processed * 1000 + print( + f" Processed {processed}/{len(prompts)} {progress_name} - Avg: {avg_time:.2f}ms per item" + ) + + avg_time_per_prompt = total_time / len(prompts) * 1000 + print( + f"Batch tokenization completed: {total_time:.4f}s total ({avg_time_per_prompt:.2f}ms per {progress_name[:-1]})" + ) + + return prompt_lengths, prompt_token_ids + + @dataclass class SampleRequest: """ @@ -47,6 +209,7 @@ class SampleRequest: prompt: Union[str, Any] prompt_len: int expected_output_len: int + multi_modal_data: Optional[dict] = None # ----------------------------------------------------------------------------- @@ -54,6 +217,7 @@ class SampleRequest: # ----------------------------------------------------------------------------- +@auto_time_methods("load_data", "sample") class BenchmarkDataset(ABC): DEFAULT_SEED = 0 IS_MULTIMODAL = False @@ -72,11 +236,14 @@ class BenchmarkDataset(ABC): sampling. Defaults to DEFAULT_SEED. """ self.dataset_path = dataset_path + self.data = None # Set the random seed, ensuring that a None value is replaced with the # default seed. self.random_seed = (random_seed if random_seed is not None else self.DEFAULT_SEED) - self.data = None + self.rng = torch.Generator() + self.rng.manual_seed(self.random_seed) + random.seed(self.random_seed) def load_data(self) -> None: """ @@ -123,13 +290,26 @@ class BenchmarkDataset(ABC): requests. num_requests (int): The target number of requests. """ if len(requests) < num_requests: - random.seed(self.random_seed) additional = random.choices(requests, k=num_requests - len(requests)) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) + def apply_multimodal_chat_transformation(self, + prompt: str, + mm_content: Optional[dict] = None + ) -> list[dict]: + """ + Transform a prompt and optional multimodal content into a chat format. + This method is used for chat models that expect a specific conversation + format. + """ + content = [{"text": prompt, "type": "text"}] + if mm_content is not None: + content.append(mm_content) + return [{"role": "user", "content": content}] + # ----------------------------------------------------------------------------- # Utility Functions and Global Caches @@ -163,6 +343,50 @@ def is_valid_sequence( or combined_too_long) +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + Supports three input types: + + 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key + containing raw image data. - Loads the bytes as a PIL.Image.Image. + + 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as + a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns + a dictionary with the image as a base64 data URL. + + 3. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + TypeError: If the input is not a supported type. + """ + if isinstance(image, dict) and "bytes" in image: + image = Image.open(io.BytesIO(image["bytes"])) + if isinstance(image, Image.Image): + image = convert_image_mode(image, "RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise TypeError(f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes.") + + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- @@ -210,13 +434,16 @@ class RandomDataset(BenchmarkDataset): **kwargs, ) -> list[SampleRequest]: # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range") + if range_ratio >= 1.0: + raise ValueError( + "random_range_ratio must be < 1.0 to ensure a valid sampling range" + ) vocab_size = tokenizer.vocab_size - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + prefix_token_ids = (torch.randint( + 0, vocab_size, size=(prefix_len, ), generator=self.rng).tolist() + if prefix_len > 0 else []) # New sampling logic: [X * (1 - b), X * (1 + b)] input_low = int(input_len * (1 - range_ratio)) @@ -225,17 +452,22 @@ class RandomDataset(BenchmarkDataset): output_high = int(output_len * (1 + range_ratio)) # Add logging for debugging - logger.info("Sampling input_len from [%s, %s]", input_low, input_high) - logger.info("Sampling output_len from [%s, %s]", output_low, - output_high) + logger.debug("Sampling input_len from [%s, %s]", input_low, input_high) + logger.debug("Sampling output_len from [%s, %s]", output_low, + output_high) - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, - size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) + input_lens = torch.randint(input_low, + input_high + 1, + size=(num_requests, ), + generator=self.rng).tolist() + output_lens = torch.randint(output_low, + output_high + 1, + size=(num_requests, ), + generator=self.rng).tolist() + offsets = torch.randint(0, + vocab_size, + size=(num_requests, ), + generator=self.rng).tolist() requests = [] if self.sample_from_sharegpt: @@ -256,23 +488,25 @@ class RandomDataset(BenchmarkDataset): # Shuffle the dataset. random.shuffle(dataset) + # Batch tokenize all prompts first for efficiency + prompt_lengths, prompt_token_ids = batch_tokenize_prompts( + dataset, tokenizer, progress_name="random dataset prompts") + # Filter out sequences that are too long or too short requests = [] - for prompt in dataset: + for prompt, initial_prompt_len, cached_token_ids in zip( + dataset, prompt_lengths, prompt_token_ids): i = len(requests) if i == num_requests: break - # Tokenize the prompts and completions. - prompt_token_ids = tokenizer.encode(prompt) - prompt_len = len(prompt_token_ids) - # Skip empty prompt - if prompt_len == 0: + if initial_prompt_len == 0: continue - if prompt_len > input_lens[i]: - input_ids = prompt_token_ids[:input_lens[i]] + if initial_prompt_len > input_lens[i]: + # Use cached token IDs to avoid re-encoding + input_ids = cached_token_ids[:input_lens[i]] else: # Re-calculate the prompt length to exclude special tokens. prompt_len = len( @@ -281,11 +515,12 @@ class RandomDataset(BenchmarkDataset): continue ratio = (input_lens[i] + prompt_len) // prompt_len prompt = " ".join([prompt] * ratio) - prompt_token_ids = tokenizer.encode(prompt) - while len(prompt_token_ids) < input_lens[i]: + prompt_token_ids_for_truncation = tokenizer.encode(prompt) + while len(prompt_token_ids_for_truncation) < input_lens[i]: prompt += " " + prompt - prompt_token_ids = tokenizer.encode(prompt) - input_ids = prompt_token_ids[:input_lens[i]] + prompt_token_ids_for_truncation = tokenizer.encode( + prompt) + input_ids = prompt_token_ids_for_truncation[:input_lens[i]] prompt = prefix_token_ids + input_ids @@ -324,6 +559,131 @@ class RandomDataset(BenchmarkDataset): # ----------------------------------------------------------------------------- +class RandomImageDataset(BenchmarkDataset): + DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 0.0 + DEFAULT_INPUT_LEN = 128 + DEFAULT_OUTPUT_LEN = 128 + DEFAULT_WIDTH = 512 + DEFAULT_HEIGHT = 512 + DEFAULT_IMAGE_SIZE = 512 + DEFAULT_NUM_IMAGES = 1 + IS_MULTIMODAL = True + + def __init__( + self, + return_text: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.return_text = return_text + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + width: int = DEFAULT_WIDTH, + height: int = DEFAULT_HEIGHT, + image_size: int = DEFAULT_IMAGE_SIZE, + num_images: int = DEFAULT_NUM_IMAGES, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list[SampleRequest]: + # Enforce range_ratio < 1 + if range_ratio >= 1.0: + raise ValueError( + "random_range_ratio must be < 1.0 to ensure a valid sampling range" + ) + + vocab_size = tokenizer.vocab_size + + prefix_token_ids = (torch.randint( + 0, vocab_size, size=(prefix_len, ), generator=self.rng).tolist() + if prefix_len > 0 else []) + + # New sampling logic: [X * (1 - b), X * (1 + b)] + input_low = int(input_len * (1 - range_ratio)) + input_high = int(input_len * (1 + range_ratio)) + output_low = int(output_len * (1 - range_ratio)) + output_high = int(output_len * (1 + range_ratio)) + + # Add logging for debugging + logger.debug("Sampling input_len from [%s, %s]", input_low, input_high) + logger.debug("Sampling output_len from [%s, %s]", output_low, + output_high) + + input_lens = torch.randint(input_low, + input_high + 1, + size=(num_requests, ), + generator=self.rng).tolist() + output_lens = torch.randint(output_low, + output_high + 1, + size=(num_requests, ), + generator=self.rng).tolist() + offsets = torch.randint(0, + vocab_size, + size=(num_requests, ), + generator=self.rng).tolist() + + # Determine final image dimensions + # When both width/height and image_size are provided, prioritize width/height + final_width = width + final_height = height + + # If width and height are still at default values but image_size is different, use image_size + if (width == self.DEFAULT_WIDTH and height == self.DEFAULT_HEIGHT + and image_size != self.DEFAULT_IMAGE_SIZE): + final_width = image_size + final_height = image_size + logger.info("Using width: %s, height: %s for random image dimensions", + final_width, final_height) + logger.info("Generating %d images per request", num_images) + + sampled_requests = [] + for i in range(num_requests): + # Generate random text prompt + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + prompt = prefix_token_ids + inner_seq + if self.return_text: + prompt = tokenizer.decode(prompt) + total_input_len = prefix_len + int(input_lens[i]) + + # Generate random images (support multiple images per request) + images = [] + for _ in range(num_images): + random_image = torch.randint(0, + 256, + (final_height, final_width, 3), + dtype=torch.uint8, + generator=self.rng).numpy() + pil_image = Image.fromarray(random_image) + images.append(pil_image) + + # Process images for multimodal content + mm_content = [process_image(img) for img in images] + + # Handle multimodal chat transformation + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=mm_content, + )) + + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + class CustomDataset(BenchmarkDataset): """ TensorRT-LLM customized dataset implementation. @@ -358,25 +718,40 @@ class CustomDataset(BenchmarkDataset): with open(self.dataset_path, encoding="utf-8") as f: for line in f: self.data.append(json.loads(line)) - random.seed(self.random_seed) random.shuffle(self.data) def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int) -> list[SampleRequest]: - samples: list = [] - for entry in self.data: - if len(samples) >= num_requests: + """ + Optimized version using batch tokenization for better performance. + """ + # Collect all prompts and metadata + prompts = [] + max_tokens_list = [] + + for i, entry in enumerate(self.data): + if len(prompts) >= num_requests: break prompt = entry["input"]["messages"][1]["content"] - prompt_ids = tokenizer(prompt).input_ids - prompt_len = len(prompt_ids) max_tokens = entry["input"]["max_tokens"] + prompts.append(prompt) + max_tokens_list.append(max_tokens) + + # Use batch tokenization utility + prompt_lengths, _ = batch_tokenize_prompts( + prompts, tokenizer, progress_name="custom dataset prompts") + + # Create SampleRequest objects + samples = [] + for prompt, prompt_len, max_tokens in zip(prompts, prompt_lengths, + max_tokens_list): samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=max_tokens, )) + return samples @@ -415,7 +790,6 @@ class ShareGPTDataset(BenchmarkDataset): entry for entry in self.data if "conversations" in entry and len(entry["conversations"]) >= 2 ] - random.seed(self.random_seed) random.shuffle(self.data) def sample( @@ -428,33 +802,47 @@ class ShareGPTDataset(BenchmarkDataset): enable_multimodal_chat: bool = False, **kwargs, ) -> list: - samples: list = [] + if enable_multimodal_chat: + raise NotImplementedError + + # Collect prompts and completions for batch processing + prompts = [] + completions = [] + for entry in self.data: - if len(samples) >= num_requests: + if len(prompts) >= num_requests: break prompt, completion = ( entry["conversations"][0]["value"], entry["conversations"][1]["value"], ) + prompts.append(prompt) + completions.append(completion) - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - new_output_len = (len(completion_ids) - if output_len is None else output_len) + # Batch tokenize prompts and completions + prompt_lengths, _ = batch_tokenize_prompts( + prompts, tokenizer, progress_name="ShareGPT prompts") + completion_lengths, _ = batch_tokenize_prompts( + completions, tokenizer, progress_name="ShareGPT completions") + + # Filter and create samples + samples: list = [] + for prompt, completion, prompt_len, completion_len in zip( + prompts, completions, prompt_lengths, completion_lengths): + new_output_len = completion_len if output_len is None else output_len if not is_valid_sequence(prompt_len, new_output_len, skip_min_output_len_check=output_len is not None): continue - if enable_multimodal_chat: - raise NotImplementedError + samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=new_output_len, )) + self.maybe_oversample_requests(samples, num_requests) return samples @@ -498,10 +886,11 @@ class SonnetDataset(BenchmarkDataset): return_prompt_formatted: bool = False, **kwargs, ) -> list: - # Calculate average token length for a poem line. - tokenized_lines = [tokenizer(line).input_ids for line in self.data] - avg_len = sum(len(tokens) - for tokens in tokenized_lines) / len(tokenized_lines) + # Calculate average token length for poem lines using batch tokenization + line_lengths, _ = batch_tokenize_prompts(self.data, + tokenizer, + progress_name="sonnet lines") + avg_len = sum(line_lengths) / len(line_lengths) # Build the base prompt. base_prompt = "Pick as many lines as you can from these poem lines:\n" @@ -658,34 +1047,47 @@ class ConversationDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, **kwargs) -> list: - # Filter examples with at least 2 conversations + if enable_multimodal_chat: + raise NotImplementedError + + # Filter examples with at least 2 conversations and collect data filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) - sampled_requests = [] + prompts = [] + completions = [] dynamic_output = output_len is None for item in filtered_data: - if len(sampled_requests) >= num_requests: + if len(prompts) >= num_requests: break conv = item["conversations"] prompt, completion = conv[0]["value"], conv[1]["value"] + prompts.append(prompt) + completions.append(completion) - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - completion_len = len(completion_ids) - output_len = completion_len if dynamic_output else output_len - assert isinstance(output_len, int) and output_len > 0 + # Batch tokenize prompts and completions + prompt_lengths, _ = batch_tokenize_prompts( + prompts, tokenizer, progress_name="conversation prompts") + completion_lengths, _ = batch_tokenize_prompts( + completions, tokenizer, progress_name="conversation completions") + + # Filter and create samples + sampled_requests = [] + for prompt, completion, prompt_len, completion_len in zip( + prompts, completions, prompt_lengths, completion_lengths): + current_output_len = completion_len if dynamic_output else output_len + assert isinstance(current_output_len, + int) and current_output_len > 0 if dynamic_output and not is_valid_sequence(prompt_len, completion_len): continue - if enable_multimodal_chat: - raise NotImplementedError + sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, - expected_output_len=output_len, + expected_output_len=current_output_len, )) + self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -717,25 +1119,33 @@ class VisionArenaDataset(HuggingFaceDataset): enable_multimodal_chat: bool = False, **kwargs, ) -> list: + if enable_multimodal_chat: + raise NotImplementedError + output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) - sampled_requests = [] + + # Collect prompts for batch processing + prompts = [] + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + if parser_fn is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + for item in self.data: - if len(sampled_requests) >= num_requests: + if len(prompts) >= num_requests: break - parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) - if parser_fn is None: - raise ValueError( - f"Unsupported dataset path: {self.dataset_path}") prompt = parser_fn(item) + mm_content = process_image(item["images"][0]) prompt_len = len(tokenizer(prompt).input_ids) if enable_multimodal_chat: - raise NotImplementedError + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + multi_modal_data=mm_content, )) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -769,12 +1179,22 @@ class InstructCoderDataset(HuggingFaceDataset): **kwargs) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) - sampled_requests = [] + + # Collect prompts for batch processing + prompts = [] for item in self.data: - if len(sampled_requests) >= num_requests: + if len(prompts) >= num_requests: break prompt = f"{item['instruction']}:\n{item['input']}" - prompt_len = len(tokenizer(prompt).input_ids) + prompts.append(prompt) + + # Batch tokenize prompts + prompt_lengths, _ = batch_tokenize_prompts( + prompts, tokenizer, progress_name="instruct coder prompts") + + # Create samples + sampled_requests = [] + for prompt, prompt_len in zip(prompts, prompt_lengths): sampled_requests.append( SampleRequest( prompt=prompt, @@ -813,22 +1233,31 @@ class MTBenchDataset(HuggingFaceDataset): **kwargs) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) - sampled_requests = [] + # Collect prompts for batch processing + prompts = [] for item in self.data: - if len(sampled_requests) >= num_requests: + if len(prompts) >= num_requests: break - prompt = item['turns'][0] + raw_prompt = item['turns'][0] # apply template - prompt = tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - add_generation_prompt=True, - tokenize=False) + formatted_prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": raw_prompt + }], + add_generation_prompt=True, + tokenize=False) + prompts.append(formatted_prompt) - prompt_len = len(tokenizer(prompt).input_ids) + # Batch tokenize prompts + prompt_lengths, _ = batch_tokenize_prompts( + prompts, tokenizer, progress_name="MT-Bench prompts") + + # Create samples + sampled_requests = [] + for prompt, prompt_len in zip(prompts, prompt_lengths): sampled_requests.append( SampleRequest( prompt=prompt, @@ -858,20 +1287,32 @@ class AIMODataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, **kwargs) -> list: - sampled_requests = [] dynamic_output = output_len is None + # Collect prompts and completions for batch processing + prompts = [] + completions = [] for item in self.data: - if len(sampled_requests) >= num_requests: + if len(prompts) >= num_requests: break prompt, completion = item['problem'], item["solution"] + prompts.append(prompt) + completions.append(completion) - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - completion_len = len(completion_ids) - output_len = completion_len if dynamic_output else output_len - assert isinstance(output_len, int) and output_len > 0 + # Batch tokenize prompts and completions + prompt_lengths, _ = batch_tokenize_prompts(prompts, + tokenizer, + progress_name="AIMO prompts") + completion_lengths, _ = batch_tokenize_prompts( + completions, tokenizer, progress_name="AIMO completions") + + # Filter and create samples + sampled_requests = [] + for prompt, completion, prompt_len, completion_len in zip( + prompts, completions, prompt_lengths, completion_lengths): + current_output_len = completion_len if dynamic_output else output_len + assert isinstance(current_output_len, + int) and current_output_len > 0 if dynamic_output and not is_valid_sequence(prompt_len, completion_len, max_prompt_len=2048, @@ -881,7 +1322,7 @@ class AIMODataset(HuggingFaceDataset): SampleRequest( prompt=prompt, prompt_len=prompt_len, - expected_output_len=output_len, + expected_output_len=current_output_len, )) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests diff --git a/tensorrt_llm/serve/scripts/benchmark_serving.py b/tensorrt_llm/serve/scripts/benchmark_serving.py index 1aeb87554d..5408db627e 100644 --- a/tensorrt_llm/serve/scripts/benchmark_serving.py +++ b/tensorrt_llm/serve/scripts/benchmark_serving.py @@ -41,8 +41,8 @@ from tensorrt_llm.serve.scripts.backend_request_func import ( RequestFuncInput, RequestFuncOutput, get_tokenizer) from tensorrt_llm.serve.scripts.benchmark_dataset import ( AIMODataset, BurstGPTDataset, ConversationDataset, CustomDataset, - HuggingFaceDataset, InstructCoderDataset, RandomDataset, SampleRequest, - ShareGPTDataset, SonnetDataset, VisionArenaDataset) + HuggingFaceDataset, InstructCoderDataset, RandomDataset, RandomImageDataset, + SampleRequest, ShareGPTDataset, SonnetDataset, VisionArenaDataset) from tensorrt_llm.serve.scripts.benchmark_utils import ( convert_to_pytorch_benchmark_format, write_to_json) # isort: on @@ -288,10 +288,13 @@ async def benchmark( if not no_test_input: print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = \ + test_prompt, test_prompt_len, test_output_len, test_mm_content = \ input_requests[0].prompt, input_requests[0].prompt_len, \ - input_requests[0].expected_output_len + input_requests[0].expected_output_len, input_requests[0].multi_modal_data + assert test_mm_content is None or isinstance( + test_mm_content, list) and all( + isinstance(item, dict) for item in test_mm_content) test_input = RequestFuncInput( model=model_id, model_name=model_name, @@ -302,6 +305,7 @@ async def benchmark( logprobs=logprobs, ignore_eos=ignore_eos, extra_body=extra_body, + multi_modal_content=test_mm_content, ) test_output = await request_func(request_func_input=test_input, @@ -323,15 +327,18 @@ async def benchmark( if profile: print("Starting profiler...") - profile_input = RequestFuncInput(model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - ignore_eos=ignore_eos, - extra_body=extra_body) + profile_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + ignore_eos=ignore_eos, + extra_body=extra_body, + multi_modal_content=test_mm_content, + ) profile_output = await request_func(request_func_input=profile_input, streaming=streaming) if profile_output.success: @@ -379,23 +386,26 @@ async def benchmark( i = 0 async for request in get_request(input_requests, request_rate, burstiness): - prompt, prompt_len, output_len = request.prompt, \ - request.prompt_len, request.expected_output_len + prompt, prompt_len, output_len, mm_content = request.prompt, \ + request.prompt_len, request.expected_output_len, request.multi_modal_data req_model_id, req_model_name = model_id, model_name if lora_modules: req_lora_module = next(lora_modules) req_model_id, req_model_name = req_lora_module, req_lora_module - request_func_input = RequestFuncInput(model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - ignore_eos=ignore_eos, - extra_body=extra_body) + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + ignore_eos=ignore_eos, + extra_body=extra_body, + multi_modal_content=mm_content, + ) tasks.append( asyncio.create_task( limited_request_func(request_func_input=request_func_input, @@ -581,7 +591,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={k: [results[k]] - for k in metrics}, + for k in metrics if k in results}, extra_info={ k: results[k] for k in results if k not in metrics and k not in ignored_metrics @@ -603,6 +613,9 @@ def main(args: argparse.Namespace): tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_mode = args.tokenizer_mode + if backend == "openai-chat": + args.endpoint = "/v1/chat/completions" + if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" base_url = f"{args.base_url}" @@ -665,6 +678,13 @@ def main(args: argparse.Namespace): f" from one of following: {supported_datasets}. " "Please consider contributing if you would " "like to add support for additional dataset formats.") + if dataset_class.IS_MULTIMODAL and backend not in [ + "openai-chat", + ]: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' backend." + ) input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -684,41 +704,78 @@ def main(args: argparse.Namespace): ) else: - # For datasets that follow a similar structure, use a mapping. - dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(download_path=args.download_path, - download_timeout=args.download_timeout, - random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). - sample(tokenizer=tokenizer, num_requests=args.num_prompts), - "random": - lambda: RandomDataset(sample_from_sharegpt=not args.random_ids, - return_text=not args.tokenize_on_client, - dataset_path=args.dataset_path, - download_path=args.download_path, - download_timeout=args.download_timeout - ).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - prefix_len=args.random_prefix_len, - input_len=args.random_input_len, - output_len=args.random_output_len, - range_ratio=args.random_range_ratio, - ) + + def create_dataset_and_sample(dataset_name: str): + """Factory function to create dataset instance and generate samples.""" + + # Dataset factory mapping with lambda functions for lazy evaluation + dataset_factories = { + "sharegpt": + lambda: ShareGPTDataset(download_path=args.download_path, + download_timeout=args.download_timeout, + random_seed=args.seed, + dataset_path=args.dataset_path). + sample(tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len), + "burstgpt": + lambda: BurstGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path). + sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": + lambda: RandomDataset(sample_from_sharegpt=not args.random_ids, + return_text=not args.tokenize_on_client, + dataset_path=args.dataset_path, + download_path=args.download_path, + download_timeout=args.download_timeout, + random_seed=args.seed).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio), + "random_image": + lambda: RandomImageDataset( + random_seed=args.seed, + return_text=not args.tokenize_on_client, + ).sample(tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + width=args.random_image_width, + height=args.random_image_height, + image_size=args.random_image_size, + num_images=args.random_num_images), + } + + if dataset_name not in dataset_factories: + raise ValueError( + f"Unknown dataset: {dataset_name}. " + f"Available datasets: {list(dataset_factories.keys())}") + + return dataset_factories[dataset_name]() + + # Check multimodal compatibility before creating dataset + dataset_class_mapping = { + "sharegpt": ShareGPTDataset, + "burstgpt": BurstGPTDataset, + "random": RandomDataset, + "random_image": RandomImageDataset, } - try: - input_requests = dataset_mapping[args.dataset_name]() - except KeyError as err: - raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + dataset_class = dataset_class_mapping.get(args.dataset_name) + if dataset_class and dataset_class.IS_MULTIMODAL and backend not in [ + "openai-chat" + ]: + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' backend." + ) + + # Create dataset and generate samples + input_requests = create_dataset_and_sample(args.dataset_name) goodput_config_dict = check_goodput_args(args) # Collect the sampling parameters. @@ -856,7 +913,8 @@ if __name__ == "__main__": type=str, default="sharegpt", choices=[ - "sharegpt", "burstgpt", "sonnet", "random", "hf", "trtllm_custom" + "sharegpt", "burstgpt", "sonnet", "random", "random_image", "hf", + "trtllm_custom" ], help="Name of the dataset to benchmark on.", ) @@ -1106,6 +1164,32 @@ if __name__ == "__main__": help= "Tokenize on client instead of server. This option only takes effect with random dataset to let the server run exactly the same ISL specified by cli.", ) + random_image_group = parser.add_argument_group( + "random image dataset options") + random_image_group.add_argument( + "--random-image-width", + type=int, + default=512, + help="Width of the image.", + ) + random_image_group.add_argument( + "--random-image-height", + type=int, + default=512, + help="Height of the image.", + ) + random_image_group.add_argument( + "--random-image-size", + type=int, + default=512, + help="Squared size of the image.", + ) + random_image_group.add_argument( + "--random-num-images", + type=int, + default=1, + help="Number of images per request.", + ) hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset", diff --git a/tensorrt_llm/tools/importlib_utils.py b/tensorrt_llm/tools/importlib_utils.py new file mode 100644 index 0000000000..281258612e --- /dev/null +++ b/tensorrt_llm/tools/importlib_utils.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +from pathlib import Path +from types import ModuleType +from typing import Optional, Union + + +def import_custom_module_from_file( + custom_module_path: Union[str, Path]) -> Optional[ModuleType]: + """Import a custom module from a single file. + + Args: + custom_module_path (Union[str, Path]): The path to the custom module file. + + Returns: + The imported module object. + + Raises: + ImportError: If the module cannot be imported. + """ + if isinstance(custom_module_path, str): + custom_module_path = Path(custom_module_path) + print(f"Importing custom module from file: {custom_module_path}") + + # Import single Python file + module = None + spec = importlib.util.spec_from_file_location(custom_module_path.stem, + str(custom_module_path)) + if spec is not None: + module = importlib.util.module_from_spec(spec) + if spec.loader is not None: + spec.loader.exec_module(module) + print( + f"Successfully imported custom module from file: {custom_module_path}" + ) + else: + raise ImportError( + f"Failed to import custom module from {custom_module_path}") + else: + raise ImportError( + f"Failed to import custom module from {custom_module_path}") + return module + + +def import_custom_module_from_dir( + custom_module_path: Union[str, Path]) -> Optional[ModuleType]: + """Import a custom module from a directory. + + Args: + custom_module_path (Union[str, Path]): The path to the custom module directory. + + Returns: + The imported module object. + + Raises: + ImportError: If the module cannot be imported. + + Note: + This function will add the parent directory of the custom module directory to sys.path. + This is useful for importing modules that are not in the current working directory. + """ + if isinstance(custom_module_path, str): + custom_module_path = Path(custom_module_path) + print(f"Importing custom module from directory: {custom_module_path}") + + # Import directory as a package + # Add the parent directory to sys.path so we can import the package + import sys + parent_dir = str(custom_module_path.parent) + if parent_dir not in sys.path: + sys.path.insert(0, parent_dir) + + # Import the package + module = None + package_name = custom_module_path.name + try: + module = importlib.import_module(package_name) + print( + f"Successfully imported custom module from directory: {custom_module_path}" + ) + except ImportError as e: + raise ImportError( + f"Failed to import package {package_name} from {custom_module_path}: {e}" + ) + return module + + +def import_custom_module( + custom_module_path: Union[str, Path]) -> Optional[ModuleType]: + """Import a custom module from a file or directory. + + Args: + custom_module_path (Union[str, Path]): The path to the custom module file or directory. + + Returns: + The imported module object. + + Raises: + ImportError: If the module cannot be imported. + FileNotFoundError: If the custom module path does not exist. + """ + if isinstance(custom_module_path, str): + custom_module_path = Path(custom_module_path) + print(f"Importing custom module from: {custom_module_path}") + + if custom_module_path.exists(): + if custom_module_path.is_file(): + return import_custom_module_from_file(custom_module_path) + elif custom_module_path.is_dir(): + return import_custom_module_from_dir(custom_module_path) + else: + raise FileNotFoundError( + f"Custom module path {custom_module_path} is neither a file nor a directory." + ) + else: + raise FileNotFoundError( + f"Custom module path {custom_module_path} does not exist.") + return None diff --git a/tensorrt_llm/tools/multimodal_builder.py b/tensorrt_llm/tools/multimodal_builder.py index 7fbbb018ec..9a2096852b 100644 --- a/tensorrt_llm/tools/multimodal_builder.py +++ b/tensorrt_llm/tools/multimodal_builder.py @@ -25,25 +25,25 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, import torch.nn as nn import torch.nn.functional as F from PIL import Image -from safetensors.torch import save_file +from safetensors.torch import load_model, save_file from transformers import CLIPImageProcessor from ..runtime.session import Session def add_multimodal_arguments(parser): - parser.add_argument('--model_type', - type=str, - default=None, - choices=[ - 'blip2', 'llava', 'llava_next', 'llava_onevision', - 'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm', - 'fuyu', 'pix2struct', 'neva', 'kosmos-2', - 'video-neva', 'phi-3-vision', 'phi-4-multimodal', - 'mllama', 'internvl', 'qwen2_vl', - 'internlm-xcomposer2', 'qwen2_audio', 'pixtral' - ], - help="Model type") + parser.add_argument( + '--model_type', + type=str, + default=None, + choices=[ + 'blip2', 'llava', 'llava_next', 'llava_onevision', + 'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm', 'fuyu', + 'pix2struct', 'neva', 'kosmos-2', 'video-neva', 'phi-3-vision', + 'phi-4-multimodal', 'mllama', 'internvl', 'qwen2_vl', + 'internlm-xcomposer2', 'qwen2_audio', 'pixtral', 'eclair' + ], + help="Model type") parser.add_argument( '--model_path', type=str, @@ -144,6 +144,8 @@ class MultimodalEngineBuilder: build_qwen2_audio_engine(args) elif args.model_type == "pixtral": build_pixtral_engine(args) + elif args.model_type == "eclair": + build_eclair_engine(args) else: raise RuntimeError(f"Invalid model type {args.model_type}") @@ -1322,6 +1324,7 @@ def compute_rotary_pos_emb(grid_thw, hf_config, VisionRotaryEmbedding): def build_qwen2_vl_engine(args): + import transformers from qwen_vl_utils import process_vision_info from transformers import AutoProcessor, Qwen2VLForConditionalGeneration from transformers.models.qwen2_vl.configuration_qwen2_vl import \ @@ -1389,8 +1392,15 @@ def build_qwen2_vl_engine(args): class VisionAttentionOpt(VisionAttention): def __init__(self, config: Qwen2VLVisionConfig): - super().__init__(config) - self.head_dim = config.embed_dim // config.num_heads + # Fallback for compatibility with older transformers versions (for certain nvbugs/tests) + if transformers.__version__ >= '4.53.0': + super().__init__(config) + self.head_dim = config.embed_dim // config.num_heads + else: + num_heads = config.num_heads + dim = config.embed_dim + super().__init__(dim, num_heads) + self.head_dim = dim // num_heads def forward(self, hidden_states: torch.Tensor, @@ -1739,3 +1749,78 @@ def build_pixtral_engine(args): max_batch_size=args.max_batch_size, engine_name=f"model.engine", dtype=torch.bfloat16) + + +def build_eclair_engine(args): + + class RadioWithNeck(torch.nn.Module): + + def __init__(self): + super().__init__() + + try: + self.model_encoder = torch.hub.load("NVlabs/RADIO", + "radio_model", + version="radio_v2.5-h") + except Exception as e: + raise RuntimeError( + f"Failed to load RADIO model from torch.hub: {e}") + self.model_encoder.summary_idxs = torch.tensor(4) + + self.conv1 = torch.nn.Conv1d(1280, 1024, 1) + self.layer_norm1 = torch.nn.LayerNorm(1024, + eps=1e-6, + elementwise_affine=True) + self.conv2 = torch.nn.Conv2d(1024, + 1024, + kernel_size=(1, 4), + stride=(1, 4), + padding=0, + bias=False) + self.layer_norm2 = torch.nn.LayerNorm(1024, + eps=1e-6, + elementwise_affine=True) + + @torch.no_grad + def forward(self, pixel_values): + _, feature = self.model_encoder(pixel_values) + output = self.conv1(feature.permute(0, 2, 1)).permute(0, 2, 1) + output = self.layer_norm1(output).permute(0, 2, 1) + + b, d, _ = output.shape + h = pixel_values.shape[-2] // 16 + w = pixel_values.shape[-1] // 16 + output = self.conv2(output.reshape(b, d, h, w)) + output = output.flatten(-2, -1).permute(0, 2, 1) + output = self.layer_norm2(output) + return output + + processor = NougatProcessor.from_pretrained(args.model_path) + model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base") + model.encoder = RadioWithNeck() + model.decoder.resize_token_embeddings(len(processor.tokenizer)) + model.config.decoder_start_token_id = processor.tokenizer.eos_token_id # 2 + model.config.pad_token_id = processor.tokenizer.pad_token_id # 1 + checkpoint_path = os.path.join(args.model_path, "model.safetensors") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"Model checkpoint not found at {checkpoint_path}") + load_model(model, checkpoint_path) + + wrapper = model.encoder.to(args.device) + # temporary fix due to TRT onnx export bug + for block in wrapper.model_encoder.model.blocks: + block.attn.fused_attn = False + + image = torch.randn((1, 3, 2048, 1648), + device=args.device, + dtype=torch.bfloat16) + export_onnx(wrapper, image, f'{args.output_dir}/onnx') + build_trt_engine( + args.model_type, + [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] + f'{args.output_dir}/onnx', + args.output_dir, + args.max_batch_size, + dtype=torch.bfloat16, + engine_name='visual_encoder.engine') diff --git a/tensorrt_llm/top_model_mixin.py b/tensorrt_llm/top_model_mixin.py index 61e8dcfa4f..4d3702dca5 100644 --- a/tensorrt_llm/top_model_mixin.py +++ b/tensorrt_llm/top_model_mixin.py @@ -15,7 +15,7 @@ from typing import Optional -from .lora_manager import LoraConfig +from .lora_helper import LoraConfig from .mapping import Mapping from .plugin.plugin import PluginConfig diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index 6feaea3d2e..603fd689b7 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.0.0rc6" +__version__ = "1.1.0rc1" diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations index a2ca6317b6..23a7d075d9 100644 --- a/tests/integration/defs/.test_durations +++ b/tests/integration/defs/.test_durations @@ -281,11 +281,11 @@ "disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]": 98.97588296607137, "disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]": 67.9668476767838, "test_unittests.py::test_unittests_v2[unittest/_torch/test_attention_mla.py]": 26.32902159006335, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 591.2785023800097, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]": 306.84709841990843, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 220.57452515885234, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]": 202.22269394202158, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]": 165.08514453098178, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]": 591.2785023800097, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]": 306.84709841990843, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]": 220.57452515885234, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]": 202.22269394202158, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]": 165.08514453098178, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]": 252.70569713797886, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]": 85.24235329206567, "test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]": 81.43792725296225, @@ -305,11 +305,11 @@ "test_e2e.py::test_llmapi_load_engine_from_build_command[llama-llama-models/llama-7b-hf]": 200.82293555140495, "test_unittests.py::test_unittests_v2[unittest/trt/model/test_llama.py]": 1494.1103300452232, "test_unittests.py::test_unittests_v2[unittest/trt/attention/test_gpt_attention.py -k \"partition0\"]": 77.31474154582247, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 295.3527018489549, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]": 143.84012729604729, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]": 107.58471493399702, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]": 205.7252635700861, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 113.82226522010751, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]": 295.3527018489549, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]": 143.84012729604729, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]": 107.58471493399702, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]": 205.7252635700861, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]": 113.82226522010751, "examples/test_llama.py::test_llm_llama_1gpu[llama-3.1-8b-instruct-hf-fp8-enable_fp8-float16-summarization-nb:1]": 853.2910006027669, "test_e2e.py::test_openai_chat_example": 876.1966922096908, "test_e2e.py::test_trtllm_serve_example": 200.09309104084969, diff --git a/tests/integration/defs/accuracy/accuracy_core.py b/tests/integration/defs/accuracy/accuracy_core.py index 2bfb962d35..c72917701f 100644 --- a/tests/integration/defs/accuracy/accuracy_core.py +++ b/tests/integration/defs/accuracy/accuracy_core.py @@ -25,6 +25,7 @@ import yaml import tensorrt_llm.evaluate from tensorrt_llm import LLM as PyTorchLLM from tensorrt_llm._tensorrt_engine import LLM +from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM from tensorrt_llm.builder import BuildConfig from tensorrt_llm.llmapi import SamplingParams from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig @@ -144,7 +145,7 @@ class AccuracyTask: return num_samples, threshold def evaluate(self, - llm: Union[LLM, PyTorchLLM], + llm: Union[LLM, PyTorchLLM, AutoDeployLLM], extra_acc_spec: Optional[str] = None, extra_evaluator_kwargs: Optional[dict] = None, sampling_params: Optional[SamplingParams] = None, @@ -155,6 +156,8 @@ class AccuracyTask: spec_dec_algo = None elif isinstance(llm.args.speculative_config, DecodingBaseConfig): spec_dec_algo = llm.args.speculative_config.decoding_type + if spec_dec_algo == 'AUTO': + spec_dec_algo = 'NGram' else: raise ValueError( f"Not recognized speculative_config: {llm.args.speculative_config}." @@ -192,7 +195,11 @@ class AccuracyTask: evaluator_kwargs.update(extra_evaluator_kwargs) evaluator = self.EVALUATOR_CLS(num_samples=num_samples, **evaluator_kwargs) - accuracy = evaluator.evaluate(llm, sampling_params, streaming) + evaluate_kwargs = {} + if hasattr(self, 'EVALUATE_KWARGS'): + evaluate_kwargs.update(self.EVALUATE_KWARGS) + accuracy = evaluator.evaluate(llm, sampling_params, streaming, + **evaluate_kwargs) if self.HIGHER_IS_BETTER: assert accuracy >= threshold, f"Expected accuracy >= {threshold}, but got {accuracy}." else: @@ -298,6 +305,8 @@ class GSM8K(AccuracyTask): EVALUATOR_CLS = tensorrt_llm.evaluate.GSM8K EVALUATOR_KWARGS = dict(dataset_path=DATASET_DIR, random_seed=0) + EVALUATE_KWARGS = dict(scores_filter=None) + class GPQADiamond(AccuracyTask): DATASET = "gpqa_diamond" diff --git a/tests/integration/defs/accuracy/references/gpqa_diamond.yaml b/tests/integration/defs/accuracy/references/gpqa_diamond.yaml index 56b729c7c0..f729cef1bd 100644 --- a/tests/integration/defs/accuracy/references/gpqa_diamond.yaml +++ b/tests/integration/defs/accuracy/references/gpqa_diamond.yaml @@ -3,6 +3,9 @@ meta-llama/Llama-3.3-70B-Instruct: - quant_algo: NVFP4 kv_cache_quant_algo: FP8 accuracy: 45.55 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + accuracy: 48.03 - quant_algo: FP8 accuracy: 48.03 deepseek-ai/DeepSeek-R1: diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index f69f02eaeb..26de82cbc0 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -17,6 +17,8 @@ meta-llama/Llama-3.3-70B-Instruct: - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 84.08 + - quant_algo: FP8 + accuracy: 84.08 meta-llama/Llama-4-Maverick-17B-128E-Instruct: - accuracy: 92.20 meta-llama/Llama-4-Scout-17B-16E-Instruct: @@ -78,6 +80,7 @@ Qwen3/Qwen3-8B: kv_cache_quant_algo: FP8 accuracy: 87.1114 Qwen3/Qwen3-30B-A3B: + - accuracy: 83.43 - quant_algo: FP8_BLOCK_SCALES accuracy: 84.36 - quant_algo: FP8 @@ -156,5 +159,13 @@ microsoft/Phi-4-multimodal-instruct-long-rope: - accuracy: 75.85 microsoft/Phi-4-mini-instruct: - accuracy: 82.30 +GPT-OSS/BF16: + - accuracy: 90.3 +GPT-OSS/MXFP4: + - accuracy: 90.3 + - quant_algo: W4A8_MXFP4_MXFP8 + accuracy: 90.3 + - quant_algo: W4A8_MXFP4_FP8 + accuracy: 90.3 LGAI-EXAONE/EXAONE-4.0-32B: - accuracy: 88.36 diff --git a/tests/integration/defs/accuracy/references/json_mode_eval.yaml b/tests/integration/defs/accuracy/references/json_mode_eval.yaml index d22461d8aa..f8b82fef8e 100644 --- a/tests/integration/defs/accuracy/references/json_mode_eval.yaml +++ b/tests/integration/defs/accuracy/references/json_mode_eval.yaml @@ -1,2 +1,6 @@ meta-llama/Llama-3.1-8B-Instruct: - accuracy: 74.00 + - spec_dec_algo: Eagle + accuracy: 74.00 + - spec_dec_algo: NGram + accuracy: 74.00 diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml index 485ad7c029..7f2bb55e6f 100644 --- a/tests/integration/defs/accuracy/references/mmlu.yaml +++ b/tests/integration/defs/accuracy/references/mmlu.yaml @@ -67,6 +67,8 @@ meta-llama/Llama-3.3-70B-Instruct: - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 81.02 + - quant_algo: FP8 + accuracy: 80.34 meta-llama/Llama-4-Maverick-17B-128E-Instruct: - accuracy: 86.40 - quant_algo: FP8 @@ -162,6 +164,10 @@ deepseek-ai/DeepSeek-R1: spec_dec_algo: MTP accuracy: 87.573 Qwen3/Qwen3-8B: + - quant_algo: W4A8_MXFP4_FP8 + accuracy: 72.70 + - quant_algo: W4A8_MXFP4_MXFP8 + accuracy: 72.70 - quant_algo: FP8_BLOCK_SCALES accuracy: 76.12 - accuracy: 76.12 @@ -176,6 +182,12 @@ Qwen3/Qwen3-30B-A3B: - quant_algo: NVFP4 kv_cache_quant_algo: FP8 accuracy: 80.65 + - quant_algo: W4A8_MXFP4_FP8 + accuracy: 79.78 + - quant_algo: W4A8_MXFP4_MXFP8 + accuracy: 79.78 + - quant_algo: W4A16_MXFP4 + accuracy: 79.80 Qwen3/Qwen3-235B-A22B: - quant_algo: FP8 kv_cache_quant_algo: FP8 @@ -244,3 +256,11 @@ microsoft/Phi-4-multimodal-instruct-long-rope: - accuracy: 65.98 LGAI-EXAONE/EXAONE-4.0-32B: - accuracy: 78.52 +GPT-OSS/BF16: + - accuracy: 77.50 +GPT-OSS/MXFP4: + - accuracy: 75.50 + - quant_algo: W4A8_MXFP4_MXFP8 + accuracy: 75.50 + - quant_algo: W4A8_MXFP4_FP8 + accuracy: 75.50 diff --git a/tests/integration/defs/accuracy/test_cli_flow.py b/tests/integration/defs/accuracy/test_cli_flow.py index 6dad17e8f9..a3175c2cc6 100644 --- a/tests/integration/defs/accuracy/test_cli_flow.py +++ b/tests/integration/defs/accuracy/test_cli_flow.py @@ -958,6 +958,7 @@ class TestLlama3_2_1B(CliFlowAccuracyTestHarness): # TODO: Remove the CLI tests once NIMs use PyTorch backend +@pytest.mark.skip_less_device_memory(80000) class TestLlama3_3_70BInstruct(CliFlowAccuracyTestHarness): MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct" MODEL_PATH = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct" diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index ab3ffb50f8..51a572ce49 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -4,6 +4,7 @@ # Please take a look at the existing test_llm_api_pytorch.py file for reference. import concurrent import contextlib +import json import os import tempfile import time @@ -19,12 +20,13 @@ import yaml from tensorrt_llm.executor.result import GenerationResultBase from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams from tensorrt_llm.llmapi.llm_args import LlmArgs +from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids, - skip_pre_hopper) + skip_pre_blackwell, skip_pre_hopper) from ..trt_test_alternative import popen -from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness, - get_accuracy_task) +from .accuracy_core import (GSM8K, MMLU, JsonModeEval, + LlmapiAccuracyTestHarness, get_accuracy_task) class Result(GenerationResultBase): @@ -43,7 +45,7 @@ class Result(GenerationResultBase): return self -DuckLLM = namedtuple('DuckLLM', ['args', 'generate_async']) +DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async']) class MyThreadPoolExecutor(ThreadPoolExecutor): @@ -69,7 +71,9 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], ctx_server_config: Dict[str, Any], gen_server_config: Dict[str, Any], model_name: str, - tensor_parallel_size: int = 1): + tensor_parallel_size: int = 1, + ctx_model: str = None, + gen_model: str = None): temp_dir = tempfile.TemporaryDirectory() disaggregated_serving_config_path = os.path.join( temp_dir.name, "disaggregated_serving_config.yaml") @@ -95,9 +99,19 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], trtllm_serve_path = "trtllm-serve" # Common arguments for both servers - common_args = [ + ctx_model = ctx_model or model_name + gen_model = gen_model or model_name + ctx_args = [ trtllm_serve_path, - model_name, + ctx_model, + "--host", + "localhost", + "--backend", + "pytorch", + ] + gen_args = [ + trtllm_serve_path, + gen_model, "--host", "localhost", "--backend", @@ -123,11 +137,11 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1" env_gen["CUDA_VISIBLE_DEVICES"] = ",".join( map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus))) - ctx_server_args = common_args + [ + ctx_server_args = ctx_args + [ "--port", "8001", "--extra_llm_api_options", ctx_server_config_path, f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}" ] - gen_server_args = common_args + [ + gen_server_args = gen_args + [ "--port", "8002", "--extra_llm_api_options", gen_server_config_path, f"--tp_size={gen_tp}", f"--pp_size={gen_pp}" ] @@ -146,7 +160,8 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], disaggregated_serving_config_path, "--server_start_timeout", "3600" ]) as disaggregated_server): - while True: + start_time = time.time() + while time.time() - start_time < 3600: time.sleep(1) try: print("Checking health endpoint") @@ -161,17 +176,35 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], def send_request(prompt: str, sampling_params: SamplingParams, streaming: bool): - response = client.completions.create( - model=model_name, - prompt=prompt, - stream=streaming, - **({ - "max_tokens": sampling_params.max_tokens, - "temperature": sampling_params.temperature, - "top_p": sampling_params.top_p, - "stop": sampling_params.stop, - "seed": sampling_params.seed - } if sampling_params else {})) + kwargs = {} + if sampling_params is not None: + kwargs.update(max_tokens=sampling_params.max_tokens, + temperature=sampling_params.temperature, + top_p=sampling_params.top_p, + stop=sampling_params.stop, + seed=sampling_params.seed) + if (guided_decoding_params := + sampling_params.guided_decoding) is not None: + extra_body = {} + if (schema := guided_decoding_params.json) is not None: + extra_body.update(response_format={ + "type": "json", + "schema": json.loads(schema) + }) + elif guided_decoding_params.json_object: + extra_body.update( + response_format={"type": "json_object"}) + else: + # TODO: Support other guided decoding types + raise ValueError( + f"Unsupported guided decoding params: {guided_decoding_params}." + ) + kwargs.update(extra_body=extra_body) + + response = client.completions.create(model=model_name, + prompt=prompt, + stream=streaming, + **kwargs) result = Result(id=0, sampling_params=sampling_params, outputs=[ @@ -191,8 +224,10 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], thread_pool.futures.append(future) return future + tokenizer = load_hf_tokenizer(model_name) + try: - yield DuckLLM(args, generate_async) + yield DuckLLM(args, tokenizer, generate_async) finally: ctx_server.terminate() gen_server.terminate() @@ -203,17 +238,21 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], disaggregated_server.wait() -def run_parallel_test(model_name: str, model_path: str, ctx_pp: int, - ctx_tp: int, gen_pp: int, gen_tp: int, - test_set: LlmapiAccuracyTestHarness): +def run_parallel_test(model_name: str, + model_path: str, + ctx_pp: int, + ctx_tp: int, + gen_pp: int, + gen_tp: int, + test_sets: List[LlmapiAccuracyTestHarness], + ctx_model: str = None, + gen_model: str = None): if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count(): pytest.fail( f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test" ) - kv_cache_config = { "free_gpu_memory_fraction": 0.5, - "enable_block_reuse": False } ctx_server_config = { "pipeline_parallel_size": ctx_pp, @@ -221,7 +260,7 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int, "disable_overlap_scheduler": True, "kv_cache_config": kv_cache_config, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" } } gen_server_config = { @@ -230,7 +269,7 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int, "disable_overlap_scheduler": True, "kv_cache_config": kv_cache_config, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" } } disaggregated_server_config = { @@ -247,10 +286,14 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int, } } with launch_disaggregated_llm(disaggregated_server_config, - ctx_server_config, gen_server_config, - model_path) as llm: - task = test_set(model_name) - task.evaluate(llm) + ctx_server_config, + gen_server_config, + model_path, + ctx_model=ctx_model, + gen_model=gen_model) as llm: + for test_set in test_sets: + task = test_set(model_name) + task.evaluate(llm) @pytest.mark.timeout(3600) @@ -265,8 +308,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): gen_server_config = { "disable_overlap_scheduler": disable_overlap_scheduler } - ctx_server_config["cache_transceiver_config"] = {"backend": "default"} - gen_server_config["cache_transceiver_config"] = {"backend": "default"} + ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} + gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} disaggregated_server_config = { "hostname": "localhost", "port": 8000, @@ -305,7 +348,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): "disable_overlap_scheduler": True, "kv_cache_config": kv_cache_config, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" } } gen_server_config = { @@ -313,7 +356,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): "speculative_config": speculative_decoding_config, "kv_cache_config": kv_cache_config, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" } } disaggregated_server_config = { @@ -356,7 +399,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): "max_num_tokens": 13393 * 2, "max_batch_size": 1, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" }, "cuda_graph_config": None, } @@ -370,7 +413,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): "max_num_tokens": 13393 * 2, "max_batch_size": 16, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" }, "cuda_graph_config": None, } @@ -393,20 +436,113 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @pytest.mark.skip_less_device(2) + @pytest.mark.skip_less_device_memory(32000) + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) + def test_guided_decoding(self, backend: str, mocker): + mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) + ctx_server_config = { + "disable_overlap_scheduler": True, + "guided_decoding_backend": backend, + "cache_transceiver_config": { + "backend": "DEFAULT" + } + } + gen_server_config = { + "guided_decoding_backend": backend, + "cache_transceiver_config": { + "backend": "DEFAULT" + } + } + disaggregated_server_config = { + "hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "urls": ["localhost:8002"] + } + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + task = JsonModeEval(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device_memory(32000) + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) + def test_guided_decoding_with_eagle3(self, backend: str, mocker): + mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) + speculative_decoding_config = { + "decoding_type": "Eagle", + "max_draft_len": 3, + "speculative_model_dir": + f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B", + "eagle3_one_model": False + } + + ctx_server_config = { + "disable_overlap_scheduler": True, + "speculative_config": speculative_decoding_config, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.8, + }, + "guided_decoding_backend": backend, + "cache_transceiver_config": { + "backend": "DEFAULT" + } + } + gen_server_config = { + "disable_overlap_scheduler": True, + "speculative_config": speculative_decoding_config, + "kv_cache_config": { + "free_gpu_memory_fraction": 0.8, + }, + "guided_decoding_backend": backend, + "cache_transceiver_config": { + "backend": "DEFAULT" + } + } + disaggregated_server_config = { + "hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "urls": ["localhost:8002"] + } + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + task = JsonModeEval(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)], ids=["tp1pp2", "tp2pp1", "tp2pp2"]) @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"]) def test_tp_pp_symmetric(self, tp, pp, testset): + if tp * pp * 2 > get_device_count(): + pytest.skip(f"Not enough devices for tp={tp}*pp={pp} test") return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp, - tp, get_accuracy_task(testset)) + tp, [get_accuracy_task(testset)]) @parametrize_with_ids("ctx_pp", [2, 4]) @parametrize_with_ids("gen_tp", [1, 2]) @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"]) def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset): + if ctx_pp * gen_tp * 2 > get_device_count(): + pytest.skip( + f"Not enough devices for ctx_pp={ctx_pp}*gen_tp={gen_tp} test") return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1, - gen_tp, get_accuracy_task(testset)) + gen_tp, [get_accuracy_task(testset)]) @pytest.mark.skip_less_device_memory(140000) @@ -415,12 +551,13 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness): MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct" MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct" + @pytest.mark.skip_less_device(8) @pytest.mark.parametrize("overlap_scheduler", [False, True]) def test_auto_dtype(self, overlap_scheduler): ctx_server_config = {"disable_overlap_scheduler": True} gen_server_config = {"disable_overlap_scheduler": overlap_scheduler} - ctx_server_config["cache_transceiver_config"] = {"backend": "default"} - gen_server_config["cache_transceiver_config"] = {"backend": "default"} + ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} + gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} # Keep this low to avoid warmup OOM in CI ctx_server_config["max_seq_len"] = 8192 gen_server_config["max_seq_len"] = 8192 @@ -453,14 +590,49 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): MODEL_NAME = "deepseek-ai/DeepSeek-V3-Lite" MODEL_PATH = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16" + def test_nixl_backend(self): + ctx_server_config = { + "disable_overlap_scheduler": True, + "cache_transceiver_config": { + "backend": "NIXL" + } + } + gen_server_config = { + "disable_overlap_scheduler": True, + "cache_transceiver_config": { + "backend": "NIXL" + } + } + disaggregated_server_config = { + "hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "urls": ["localhost:8002"] + } + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device(8) @parametrize_with_ids("overlap_scheduler", [True, False]) @parametrize_with_ids("mtp_nextn", [0, pytest.param(2, marks=skip_pre_hopper)]) def test_auto_dtype(self, overlap_scheduler, mtp_nextn): ctx_server_config = {"disable_overlap_scheduler": True} gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler} - ctx_server_config["cache_transceiver_config"] = {"backend": "default"} - gen_server_config["cache_transceiver_config"] = {"backend": "default"} + ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} + gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} if mtp_nextn > 0: ctx_server_config["speculative_config"] = { "decoding_type": "MTP", @@ -501,18 +673,22 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): @pytest.mark.parametrize("overlap_scheduler", [False, True]) def test_auto_dtype(self, overlap_scheduler): + pytest.skip( + "Currently we require full kvcache for variable sliding window. " + "This test only transfers the kvcache inside the sliding window.") + ctx_server_config = { "disable_overlap_scheduler": True, "cuda_graph_config": None, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" } } gen_server_config = { "disable_overlap_scheduler": overlap_scheduler, "cuda_graph_config": None, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" } } ctx_server_config["kv_cache_config"] = { @@ -550,20 +726,54 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): MODEL_NAME = "Qwen3/Qwen3-8B" MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8" + def test_nixl_backend(self): + ctx_server_config = { + "disable_overlap_scheduler": True, + "cache_transceiver_config": { + "backend": "NIXL" + } + } + gen_server_config = { + "disable_overlap_scheduler": True, + "cache_transceiver_config": { + "backend": "NIXL" + } + } + disaggregated_server_config = { + "hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "urls": ["localhost:8002"] + } + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.parametrize("overlap_scheduler", [False, True]) def test_auto_dtype(self, overlap_scheduler): ctx_server_config = { "disable_overlap_scheduler": True, "cuda_graph_config": None, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" } } gen_server_config = { "disable_overlap_scheduler": overlap_scheduler, "cuda_graph_config": None, "cache_transceiver_config": { - "backend": "default" + "backend": "DEFAULT" } } disaggregated_server_config = { @@ -584,3 +794,27 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): self.MODEL_PATH) as llm: task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + + +@skip_pre_blackwell +@pytest.mark.timeout(3600) +class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness): + FP4_MODEL = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_nvfp4_hf" + FP8_MODEL = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf" + + @pytest.mark.skip_less_device(4) + @pytest.mark.parametrize("ctx_pp,gen_tp", [(2, 2)], ids=["ctxpp2gentp2"]) + def test_mixed_ctx_gen_model(self, ctx_pp, gen_tp): + ctx_model = self.FP4_MODEL + gen_model = self.FP8_MODEL + return run_parallel_test("Qwen3/Qwen3-30B-A3B", + ctx_model, + ctx_pp=ctx_pp, + ctx_tp=1, + gen_pp=1, + gen_tp=gen_tp, + test_sets=[GSM8K, MMLU], + ctx_model=ctx_model, + gen_model=gen_model) diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py new file mode 100644 index 0000000000..da64969337 --- /dev/null +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM +from tensorrt_llm.sampling_params import SamplingParams + +from ..conftest import llm_models_root +from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness + + +class TestLlama3_1_8B(LlmapiAccuracyTestHarness): + MODEL_NAME = "meta-llama/Llama-3.1-8B" + MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B" + + def get_default_kwargs(self): + return { + 'skip_tokenizer_init': False, + 'trust_remote_code': True, + 'kv_cache_config': { + 'enable_block_reuse': False, + }, + 'max_batch_size': 512, + # 131072 is the max seq len for the model + 'max_seq_len': 8192, + # max num tokens is derived in the build_config, which is not used by AutoDeploy llmargs. + # Set it explicitly here to 8192 which is the default in build_config. + 'max_num_tokens': 8192, + 'skip_loading_weights': False, + 'compile_backend': 'torch-opt', + 'free_mem_ratio': 0.7, + 'cuda_graph_batch_sizes': [1, 2, 4, 8, 16, 32, 64, 128, 256] + } + + def get_default_sampling_params(self): + eos_id = -1 + beam_width = 1 + return SamplingParams(end_id=eos_id, + pad_id=eos_id, + n=beam_width, + use_beam_search=beam_width > 1) + + @pytest.mark.skip_less_device_memory(32000) + def test_auto_dtype(self): + kwargs = self.get_default_kwargs() + sampling_params = self.get_default_sampling_params() + with AutoDeployLLM(model=self.MODEL_PATH, + tokenizer=self.MODEL_PATH, + **kwargs) as llm: + task = CnnDailymail(self.MODEL_NAME) + task.evaluate(llm) + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 210122411f..89483fd262 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -18,11 +18,13 @@ import pytest from defs.conftest import get_sm_version from tensorrt_llm import LLM +from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \ + IS_TRITON_KERNELS_AVAILABLE from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig -from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, - KvCacheConfig, MoeConfig, MTPDecodingConfig, - NGramDecodingConfig, SamplingParams, - TorchCompileConfig) +from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig, + EagleDecodingConfig, KvCacheConfig, MoeConfig, + MTPDecodingConfig, NGramDecodingConfig, + SamplingParams, TorchCompileConfig) from tensorrt_llm.quantization import QuantAlgo from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper, @@ -194,8 +196,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): @skip_pre_hopper def test_fp8_llm_sampler(self): model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8" - with LLM(model_path, enable_trtllm_sampler=True, - max_batch_size=256) as llm: + with LLM(model_path, max_batch_size=256) as llm: assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 sampling_params = SamplingParams( @@ -228,7 +229,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): max_beam_width=max_beam_width, max_batch_size=16, max_seq_len=1024, - enable_trtllm_sampler=True, build_config=None) with llm: @@ -249,7 +249,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): enable_padding=True), ) kv_cache_config = KvCacheConfig( - enable_block_reuse=True + enable_block_reuse=True, free_gpu_memory_fraction=0.8 ) # both one-model and two-model supports this feature eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B" @@ -279,7 +279,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): cuda_graph_config=CudaGraphConfig(batch_sizes=[1]), ) - kv_cache_config = KvCacheConfig(enable_block_reuse=False) + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + free_gpu_memory_fraction=0.8) spec_config = NGramDecodingConfig( max_draft_len=4, @@ -302,9 +303,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) def test_guided_decoding(self, backend: str, mocker): mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) - llm = LLM(self.MODEL_PATH, - guided_decoding_backend=backend, - cuda_graph_config=CudaGraphConfig()) + llm = LLM(self.MODEL_PATH, guided_decoding_backend=backend) with llm: task = JsonModeEval(self.MODEL_NAME) task.evaluate(llm) @@ -316,12 +315,69 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) with LLM(self.MODEL_PATH, guided_decoding_backend=backend, - cuda_graph_config=CudaGraphConfig(), tensor_parallel_size=2, pipeline_parallel_size=2) as llm: task = JsonModeEval(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_hopper + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) + def test_guided_decoding_with_eagle3(self, backend: str, mocker): + mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) + cuda_graph_config = CudaGraphConfig(enable_padding=True) + spec_config = EagleDecodingConfig( + max_draft_len=3, + speculative_model_dir= + f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B", + eagle3_one_model=False) + llm = LLM(self.MODEL_PATH, + guided_decoding_backend=backend, + kv_cache_config=kv_cache_config, + cuda_graph_config=cuda_graph_config, + enable_chunked_prefill=True, + speculative_config=spec_config, + disable_overlap_scheduler=True) + with llm: + task = JsonModeEval(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_hopper + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) + def test_guided_decoding_with_ngram(self, backend: str, mocker): + mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) + cuda_graph_config = CudaGraphConfig(enable_padding=True) + spec_config = NGramDecodingConfig(max_draft_len=3, + max_matching_ngram_size=3) + llm = LLM(self.MODEL_PATH, + guided_decoding_backend=backend, + kv_cache_config=kv_cache_config, + cuda_graph_config=cuda_graph_config, + enable_chunked_prefill=True, + speculative_config=spec_config, + disable_overlap_scheduler=True) + with llm: + task = JsonModeEval(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_hopper + def test_auto_spec_decode(self): + pytorch_config = { + "cuda_graph_config": + CudaGraphConfig(batch_sizes=[1, 32, 64], enable_padding=True) + } + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + free_gpu_memory_fraction=0.5) + spec_config = AutoDecodingConfig() + with LLM(model=self.MODEL_PATH, + **pytorch_config, + kv_cache_config=kv_cache_config, + speculative_config=spec_config, + max_batch_size=64) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + class TestLlama3_2_1B(LlmapiAccuracyTestHarness): MODEL_NAME = "meta-llama/Llama-3.2-1B" @@ -367,6 +423,7 @@ class TestLlama3_2_3B(LlmapiAccuracyTestHarness): @pytest.mark.timeout(7200) @pytest.mark.skip_less_host_memory(1000000) +@pytest.mark.skip_less_device_memory(80000) # 1TB is basic requirement for large model tests. CG4 120G only has 800G host memory, and 480G is shared with GPUs. the test will cause the system crash. class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness): MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct" @@ -546,11 +603,14 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness): speculative_model_dir=eagle_model_dir) kv_cache_config = KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.75) + torch_compile_config = TorchCompileConfig( + enable_fullgraph=True, + enable_piecewise_cuda_graph=True, + max_num_streams=3) if torch_compile else None pytorch_config = dict( cuda_graph_config=CudaGraphConfig(max_batch_size=8), enable_attention_dp=False, - torch_compile_config=TorchCompileConfig( - enable_fullgraph=torch_compile)) + torch_compile_config=torch_compile_config) with LLM(model_path, kv_cache_config=kv_cache_config, tensor_parallel_size=tp_size, @@ -737,13 +797,16 @@ class TestGemma3_27BInstruct(LlmapiAccuracyTestHarness): kv_cache_config = KvCacheConfig( enable_block_reuse=False, enable_partial_reuse=False, + free_gpu_memory_fraction=0.5, ) # We use FlashInfer as the attention backend for Gemma3 VLM to support custom mask for images. # So, testing with it here. with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config, attn_backend="FLASHINFER", - cuda_graph_config=None) as llm: + cuda_graph_config=None, + max_batch_size=128, + max_seq_len=4096) as llm: task = CnnDailymail(self.MODEL_NAME) task.evaluate(llm) task = MMLU(self.MODEL_NAME) @@ -787,6 +850,9 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): task = MMLU(self.MODEL_NAME) task.evaluate(llm) + @pytest.mark.skip( + reason= + "Skipped because cyclic kv cache is disabled on the feature branch") def test_auto_dtype_vswa(self): # # NOTE: Test with VSWA kv cache config. # self.kv_cache_config.max_attention_window = [ @@ -864,6 +930,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): MODEL_PATH = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16" @pytest.mark.skip_less_device_memory(60000) + # Chunked Prefill for MLA can only be enabled on SM100 + @parametrize_with_ids( + "enable_chunked_prefill", + [False, pytest.param(True, marks=skip_pre_blackwell)]) @parametrize_with_ids("torch_compile", [False, True]) @parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler", [(False, False, False), (True, False, False), @@ -873,7 +943,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): @parametrize_with_ids("mtp_nextn", [0, pytest.param(2, marks=skip_pre_hopper)]) def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph, - overlap_scheduler, torch_compile): + overlap_scheduler, torch_compile, enable_chunked_prefill): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) torch_compile_config = TorchCompileConfig( enable_fullgraph=True, @@ -889,6 +959,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn) with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config, + enable_chunked_prefill=enable_chunked_prefill, **pytorch_config, enable_attention_dp=attention_dp, speculative_config=mtp_config) as llm: @@ -983,7 +1054,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @skip_no_hopper + @skip_pre_blackwell @parametrize_with_ids("torch_compile", [False]) @parametrize_with_ids( "fp8kv,attention_dp,cuda_graph,overlap_scheduler", @@ -1008,7 +1079,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): max_num_streams=3) if torch_compile else None) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, - use_cuda_graph=cuda_graph, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, torch_compile_config=torch_compile_config, moe_config=MoeConfig(backend="CUTEDSL"), ) @@ -1133,7 +1204,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): task.evaluate(llm) @pytest.mark.skip_less_device(4) - @skip_no_hopper + @skip_pre_blackwell @parametrize_with_ids("torch_compile", [False]) @parametrize_with_ids( "fp8kv,attention_dp,cuda_graph,overlap_scheduler", @@ -1166,7 +1237,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): max_num_streams=3) if torch_compile else None) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, - use_cuda_graph=cuda_graph, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, torch_compile_config=torch_compile_config, moe_config=MoeConfig(backend="CUTEDSL"), ) @@ -1430,7 +1501,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): @parametrize_with_ids( "quant_dtype", [ - pytest.param("none", marks=skip_pre_blackwell), + pytest.param("none", marks=skip_pre_hopper), # pytest.param("fp8", marks=skip_pre_hopper), pytest.param("nvfp4", marks=skip_pre_blackwell) ]) @@ -1704,29 +1775,28 @@ class TestLlama3_3NemotronSuper49Bv1(LlmapiAccuracyTestHarness): @pytest.mark.skip_less_device(2) @pytest.mark.skip_less_device_memory(80000) def test_auto_dtype_tp2(self): - with LLM(self.MODEL_PATH, tensor_parallel_size=2) as llm: + with LLM(self.MODEL_PATH, + tensor_parallel_size=2, + max_seq_len=8192, + max_batch_size=64) as llm: + # Run only one eval as maximal BS is not large task = MMLU(self.MODEL_NAME) task.evaluate(llm) - task = GSM8K(self.MODEL_NAME) - task.evaluate(llm) - task = GPQADiamond(self.MODEL_NAME) - task.evaluate(llm, - extra_evaluator_kwargs=dict(apply_chat_template=True)) @skip_pre_hopper @pytest.mark.skip_less_device(2) @pytest.mark.skip_device_not_contain(["H100", "B200"]) def test_fp8_prequantized_tp2(self): model_path = f"{llm_models_root()}/nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1-FP8" - with LLM(model_path, tensor_parallel_size=2) as llm: + with LLM(model_path, + tensor_parallel_size=2, + max_seq_len=8192, + max_batch_size=64) as llm: assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 + + # Run only one eval as maximal BS is not large task = MMLU(self.MODEL_NAME) task.evaluate(llm) - task = GSM8K(self.MODEL_NAME) - task.evaluate(llm) - task = GPQADiamond(self.MODEL_NAME) - task.evaluate(llm, - extra_evaluator_kwargs=dict(apply_chat_template=True)) class TestLlama3_1NemotronNano8Bv1(LlmapiAccuracyTestHarness): @@ -1957,7 +2027,8 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): pipeline_parallel_size=pp_size, moe_expert_parallel_size=ep_size, **pytorch_config, - enable_attention_dp=attention_dp) as llm: + enable_attention_dp=attention_dp, + max_batch_size=64) as llm: task = CnnDailymail(self.MODEL_NAME) task.evaluate(llm) task = MMLU(self.MODEL_NAME) @@ -1984,7 +2055,6 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): task = MMLU(self.MODEL_NAME) task.evaluate(llm) - @pytest.mark.skip_less_device_memory(140000) ## OOM on 80G H100 @parametrize_with_ids("eagle3_one_model", [True, False]) @parametrize_with_ids("enable_chunked_prefill", [False, True]) def test_eagle3(self, enable_chunked_prefill, eagle3_one_model): @@ -1992,7 +2062,10 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): disable_overlap_scheduler=True, cuda_graph_config=CudaGraphConfig(batch_sizes=[1]), ) - kv_cache_config = KvCacheConfig(enable_block_reuse=False) + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + free_gpu_memory_fraction=0.6, + ) eagle_model_dir = f"{llm_models_root()}/Qwen3/qwen3_8b_eagle3" target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B" @@ -2013,6 +2086,29 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): task = MMLU(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", + [(1, 1, 1, False, True, True)], + ids=["latency"]) + @pytest.mark.parametrize("activation_dtype", ["fp8", "mxfp8"]) + def test_w4a8_mxfp4(self, tp_size, pp_size, ep_size, attention_dp, + cuda_graph, overlap_scheduler, activation_dtype): + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None) + + llm = LLM( + f"{llm_models_root()}/mxfp4-qwen3/saved_models_Qwen3-8B_w4a8_mxfp4_{activation_dtype}_kv_none_hf", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + **pytorch_config, + enable_attention_dp=attention_dp) + with llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness): MODEL_NAME = "Qwen3/Qwen3-30B-A3B" @@ -2152,6 +2248,72 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRITON", "TRTLLM"]) + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [ + (1, 1, 1, False, True, True), + (2, 1, 2, False, True, True), + (4, 1, 4, False, True, True), + ], + ids=["latency", "ep2", "ep4"]) + @pytest.mark.parametrize("activation_dtype", ["static_fp8", "mxfp8"], + ids=["fp8", "mxfp8"]) + def test_w4a8_mxfp4(self, moe_backend, tp_size, pp_size, ep_size, + attention_dp, cuda_graph, overlap_scheduler, + activation_dtype): + if moe_backend == "TRITON": + if not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("TRITON moe backend is not available.") + if get_sm_version() < 90: + pytest.skip("TRITON moe backend requires Hopper or newer.") + if moe_backend in ["CUTLASS", "TRTLLM"] and get_sm_version() < 100: + pytest.skip( + "CUTLASS or TRTLLM moe backend requires Blackwell or newer.") + if activation_dtype == "mxfp8" and moe_backend not in [ + "TRTLLM", "CUTLASS" + ]: + pytest.skip( + "Mxfp8 is only supported for TRTLLM or CUTLASS moe backend.") + + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None) + + llm = LLM( + f"{llm_models_root()}/mxfp4-qwen3/saved_models_Qwen3-30B-A3B_w4a8_mxfp4_{activation_dtype}_kv_none_hf_moeonly", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + **pytorch_config, + enable_attention_dp=attention_dp, + moe_config=MoeConfig(backend=moe_backend)) + with llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_blackwell + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend", + [(1, 1, 1, False, True, True, "TRTLLM")], + ids=["latency-TRTLLM"]) + def test_w4a16_mxfp4(self, tp_size, pp_size, ep_size, attention_dp, + cuda_graph, overlap_scheduler, moe_backend): + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + moe_config=MoeConfig(backend=moe_backend)) + + llm = LLM( + f"{llm_models_root()}/mxfp4-qwen3/saved_models_Qwen3-30B-A3B_w4a16_mxfp4_kv_none_hf_moeonly", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + enable_attention_dp=attention_dp, + **pytorch_config) + with llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + class TestQwen3_32B(LlmapiAccuracyTestHarness): MODEL_NAME = "Qwen3/Qwen3-32B" @@ -2329,6 +2491,110 @@ class TestPhi4MM(LlmapiAccuracyTestHarness): task.evaluate(llm) +class TestGPTOSS(LlmapiAccuracyTestHarness): + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5) + + MODEL_PATH = f"{llm_models_root()}/gpt_oss/gpt-oss-120b" + + def update_task_kwargs(self, task): + task.EVALUATOR_KWARGS["fewshot_as_multiturn"] = True + task.EVALUATOR_KWARGS["apply_chat_template"] = True + task.EVALUATE_KWARGS["scores_filter"] = "exact_match,flexible-extract" + task.MAX_OUTPUT_LEN = 8192 + return task + + @pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRTLLM", "TRITON"], + ids=["cutlass", "trtllm", "triton"]) + @pytest.mark.parametrize("cuda_graph,overlap_scheduler", [ + (True, True), + ]) + def test_w4_1gpu(self, moe_backend, cuda_graph, overlap_scheduler): + if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("Triton kernels are not available") + + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None) + + llm = LLM(self.MODEL_PATH, + tensor_parallel_size=1, + pipeline_parallel_size=1, + moe_expert_parallel_size=1, + kv_cache_config=self.kv_cache_config, + **pytorch_config, + moe_config=MoeConfig(backend=moe_backend)) + + with llm: + model_name = "GPT-OSS/MXFP4" + task = GSM8K(model_name) + task = self.update_task_kwargs(task) + task.evaluate(llm) + + @pytest.mark.skip_less_device(4) + @pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRTLLM", "TRITON"]) + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [ + (4, 1, 1, False, True, True), + (4, 1, 4, False, True, True), + (4, 1, 4, True, True, True), + ], + ids=["tp4", "ep4", "dp4"]) + def test_w4_4gpus(self, moe_backend, tp_size, pp_size, ep_size, + attention_dp, cuda_graph, overlap_scheduler): + if moe_backend == "TRITON": + if not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("Triton kernels are not available") + + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None) + + llm = LLM(self.MODEL_PATH, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + kv_cache_config=self.kv_cache_config, + **pytorch_config, + enable_attention_dp=attention_dp, + moe_config=MoeConfig(backend=moe_backend)) + + with llm: + model_name = "GPT-OSS/MXFP4" + task = GSM8K(model_name) + task = self.update_task_kwargs(task) + task.evaluate(llm) + + @pytest.mark.skip_less_device(4) + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [ + (4, 1, 4, True, True, True), + ], + ids=["dp4"]) + def test_w4a16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, + overlap_scheduler, monkeypatch): + if not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("Triton kernels are not available") + monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4") + + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None) + + llm = LLM(self.MODEL_PATH, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + kv_cache_config=self.kv_cache_config, + **pytorch_config, + enable_attention_dp=attention_dp, + moe_backend="TRITON") + with llm: + model_name = "GPT-OSS/BF16" + task = GSM8K(model_name) + task = self.update_task_kwargs(task) + task.evaluate(llm) + + class TestEXAONE4(LlmapiAccuracyTestHarness): MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B" kv_cache_config = KvCacheConfig( diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance.yaml index 6db8a0f1a9..d64bac8763 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance.yaml @@ -21,7 +21,7 @@ context_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" - "localhost:8002" @@ -35,7 +35,7 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT kv_cache_config: enable_block_reuse: True enable_partial_reuse: False diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml index cc275b98c7..fe15f70085 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml @@ -17,7 +17,7 @@ context_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.1 cache_transceiver_config: - backend: "default" + backend: "DEFAULT" urls: - "localhost:8001" - "localhost:8002" @@ -33,7 +33,7 @@ generation_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.1 cache_transceiver_config: - backend: "default" + backend: "DEFAULT" urls: - "localhost:8003" - "localhost:8004" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml index 86da31c42b..3ad817167e 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml @@ -15,7 +15,7 @@ context_servers: enable_partial_reuse: True event_buffer_max_size: 1024 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -30,6 +30,6 @@ generation_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.05 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse_deepseek_v3.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse_deepseek_v3.yaml index e76a253c1a..06a4c154b4 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse_deepseek_v3.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse_deepseek_v3.yaml @@ -15,7 +15,7 @@ context_servers: enable_partial_reuse: True event_buffer_max_size: 1024 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -30,6 +30,6 @@ generation_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.05 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional.yaml index 2292fe22aa..28816380fe 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional.yaml @@ -18,7 +18,7 @@ context_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.15 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -33,6 +33,6 @@ generation_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.15 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional_deepseek_v3.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional_deepseek_v3.yaml index 345a958fa5..b7f3420272 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional_deepseek_v3.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional_deepseek_v3.yaml @@ -18,7 +18,7 @@ context_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.15 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -33,6 +33,6 @@ generation_servers: event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.15 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp2_genpp2.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp2_genpp2.yaml index e6a9ab14fe..293e3e604a 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp2_genpp2.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp2_genpp2.yaml @@ -16,7 +16,7 @@ context_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -31,6 +31,6 @@ generation_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp2_gentp2.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp2_gentp2.yaml index 6d4e326168..67f41bc7e5 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp2_gentp2.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp2_gentp2.yaml @@ -16,7 +16,7 @@ context_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -31,6 +31,6 @@ generation_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp4_genpp4.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp4_genpp4.yaml index 6621c05d49..3571692123 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp4_genpp4.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp4_genpp4.yaml @@ -16,7 +16,7 @@ context_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -31,6 +31,6 @@ generation_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite.yaml index 1f63caed57..83f9b3a3e8 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite.yaml @@ -10,7 +10,7 @@ context_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -18,6 +18,6 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp.yaml index 97c03fbbcb..57eb4ea004 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp.yaml @@ -14,7 +14,7 @@ context_servers: pipeline_parallel_size: 1 enable_attention_dp: true cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -23,6 +23,6 @@ generation_servers: pipeline_parallel_size: 1 enable_attention_dp: false cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_attention_dp_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_attention_dp_overlap.yaml index 25612d4a78..4343850c77 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_attention_dp_overlap.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_attention_dp_overlap.yaml @@ -14,7 +14,7 @@ context_servers: enable_attention_dp: true disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -24,6 +24,6 @@ generation_servers: enable_attention_dp: true disable_overlap_scheduler: False cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml index facc460330..837e5df8e3 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml @@ -14,7 +14,7 @@ context_servers: pipeline_parallel_size: 1 enable_attention_dp: true cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -25,4 +25,4 @@ generation_servers: urls: - "localhost:8002" cache_transceiver_config: - backend: default + backend: DEFAULT diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_genpp2.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_genpp2.yaml index a6e9b0c85d..ce53fd4626 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_genpp2.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_genpp2.yaml @@ -16,7 +16,7 @@ context_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -31,6 +31,6 @@ generation_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1.yaml index 729bdf2cf9..1335d63adf 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1.yaml @@ -10,7 +10,7 @@ context_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -18,7 +18,7 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" - "localhost:8003" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1_trt_backend.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1_trt_backend.yaml index 388be9d4d6..fa5dffa518 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1_trt_backend.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1_trt_backend.yaml @@ -8,7 +8,7 @@ context_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -16,7 +16,7 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" - "localhost:8003" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml index 1bc2084286..6b22665e9f 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml @@ -10,7 +10,7 @@ context_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -18,6 +18,6 @@ generation_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp.yaml index 28d4c3556e..80a1a3636a 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp.yaml @@ -11,7 +11,7 @@ context_servers: pipeline_parallel_size: 1 enable_attention_dp: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -20,6 +20,6 @@ generation_servers: pipeline_parallel_size: 1 enable_attention_dp: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one.yaml index 0d05bef459..9dfb092151 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one.yaml @@ -11,7 +11,7 @@ context_servers: pipeline_parallel_size: 1 enable_attention_dp: true cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -20,6 +20,6 @@ generation_servers: pipeline_parallel_size: 1 enable_attention_dp: false cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one_mtp.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one_mtp.yaml index fa771b9e30..4b6bc571da 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one_mtp.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one_mtp.yaml @@ -14,7 +14,7 @@ context_servers: pipeline_parallel_size: 1 enable_attention_dp: true cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -23,7 +23,7 @@ generation_servers: pipeline_parallel_size: 1 enable_attention_dp: false cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml index 9398f7ddd2..26218586f4 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml @@ -11,7 +11,7 @@ context_servers: enable_attention_dp: True disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -21,6 +21,6 @@ generation_servers: enable_attention_dp: True disable_overlap_scheduler: False cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap_cuda_graph.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap_cuda_graph.yaml index f8c04735eb..99034f8a1a 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap_cuda_graph.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap_cuda_graph.yaml @@ -10,7 +10,7 @@ context_servers: enable_attention_dp: true disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -22,6 +22,6 @@ generation_servers: enable_padding: False disable_overlap_scheduler: False cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml index 912178b7f6..4cfe18ebaf 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml @@ -9,7 +9,7 @@ context_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "mpi" + backend: "MPI" urls: - "localhost:8001" generation_servers: @@ -17,6 +17,6 @@ generation_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "mpi" + backend: "MPI" urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_nixl.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_nixl.yaml index e4fd09a1ce..3b1aa8fc0e 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_nixl.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_nixl.yaml @@ -9,7 +9,7 @@ context_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "nixl" + backend: "NIXL" urls: - "localhost:8001" generation_servers: @@ -17,6 +17,6 @@ generation_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "nixl" + backend: "NIXL" urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_overlap_cuda_graph.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_overlap_cuda_graph.yaml index 9ace31717e..4c601fbb86 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_overlap_cuda_graph.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_overlap_cuda_graph.yaml @@ -9,7 +9,7 @@ context_servers: pipeline_parallel_size: 1 disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -20,6 +20,6 @@ generation_servers: enable_padding: False disable_overlap_scheduler: False cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_ucx.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_ucx.yaml index b21637529b..d3395938ca 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_ucx.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_ucx.yaml @@ -9,7 +9,7 @@ context_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "ucx" + backend: "UCX" urls: - "localhost:8001" generation_servers: @@ -17,6 +17,6 @@ generation_servers: tensor_parallel_size: 2 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "ucx" + backend: "UCX" urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2pp2_gentp2pp2.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2pp2_gentp2pp2.yaml index 2e862eb6be..db62a89cf7 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2pp2_gentp2pp2.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2pp2_gentp2pp2.yaml @@ -16,7 +16,7 @@ context_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -31,6 +31,6 @@ generation_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cuda_graph_padding.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cuda_graph_padding.yaml index 8b992d210c..56db3df769 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cuda_graph_padding.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cuda_graph_padding.yaml @@ -16,7 +16,7 @@ context_servers: batch_sizes: [1,3000] disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -34,6 +34,6 @@ generation_servers: batch_sizes: [1,4,8,16,24,32] disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml index 4bd0322c36..5276de524a 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml @@ -10,7 +10,7 @@ context_servers: max_num_tokens: 512 max_batch_size: 256 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -18,6 +18,6 @@ generation_servers: max_num_tokens: 256 max_batch_size: 128 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only.yaml index f42ea826c0..92b1383764 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only.yaml @@ -14,7 +14,7 @@ generation_servers: enable_block_reuse: False enable_partial_reuse: False cache_transceiver_config: - backend: default + backend: DEFAULT print_iter_log: True urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_bs1.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_bs1.yaml new file mode 100644 index 0000000000..19d1eca714 --- /dev/null +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_bs1.yaml @@ -0,0 +1,37 @@ +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +hostname: localhost +port: 8000 +backend: "pytorch" +cuda_graph_config: null +free_gpu_memory_fraction: 0.2 +context_servers: + num_instances: 1 + max_batch_size: 1 + max_num_tokens: 3000 + max_seq_len: 4096 + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + enable_attention_dp: true + kv_cache_config: + free_gpu_memory_fraction: 0.2 + enable_partial_reuse: False + disable_overlap_scheduler: True + cache_transceiver_config: + backend: DEFAULT + urls: + - "localhost:8001" +generation_servers: + num_instances: 1 + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + enable_attention_dp: true + max_batch_size: 1 + max_num_tokens: 4096 + max_seq_len: 4096 + kv_cache_config: + free_gpu_memory_fraction: 0.2 + enable_partial_reuse: False + cache_transceiver_config: + backend: DEFAULT + urls: + - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_trt_backend.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_trt_backend.yaml index 6d9fc7d07f..ad706f8bf1 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_trt_backend.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_trt_backend.yaml @@ -13,7 +13,7 @@ generation_servers: enable_block_reuse: False enable_partial_reuse: False cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" - "localhost:8003" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_load_balance.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_load_balance.yaml index f0766a9c6d..f0593d9ef6 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_load_balance.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_load_balance.yaml @@ -19,7 +19,7 @@ context_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" - "localhost:8002" @@ -38,7 +38,7 @@ generation_servers: enable_partial_reuse: False disable_overlap_scheduler: False cache_transceiver_config: - backend: "default" + backend: "DEFAULT" urls: - "localhost:8003" - "localhost:8004" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml index 31e429c440..27d7ec4ee8 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml @@ -10,7 +10,7 @@ context_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -18,7 +18,7 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml index 2f779f598a..4e3417c732 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml @@ -9,7 +9,7 @@ context_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "default" + backend: "DEFAULT" urls: - "localhost:8001" generation_servers: @@ -17,7 +17,7 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: "default" + backend: "DEFAULT" urls: - "localhost:8002" speculative_config: diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml index 5cdafaed34..55990bbaa6 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml @@ -16,7 +16,7 @@ context_servers: enable_partial_reuse: False disable_overlap_scheduler: True cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -31,6 +31,6 @@ generation_servers: enable_partial_reuse: False disable_overlap_scheduler: False cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_trt_backend.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_trt_backend.yaml index 885991c886..3eb275c87e 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_trt_backend.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_trt_backend.yaml @@ -10,7 +10,7 @@ context_servers: kv_cache_config: free_gpu_memory_fraction: 0.2 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8001" generation_servers: @@ -18,6 +18,6 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 cache_transceiver_config: - backend: default + backend: DEFAULT urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml index b7ecb48b30..287d1103a4 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml @@ -11,12 +11,12 @@ context_servers: max_seq_len: 4096 tensor_parallel_size: 1 pipeline_parallel_size: 1 - enable_trtllm_sampler: True + sampler_type: "TRTLLMSampler" kv_cache_config: free_gpu_memory_fraction: 0.2 enable_partial_reuse: False cache_transceiver_config: - backend: "default" + backend: "DEFAULT" disable_overlap_scheduler: True urls: - "localhost:8001" @@ -27,12 +27,12 @@ generation_servers: max_batch_size: 256 max_num_tokens: 4096 max_seq_len: 4096 - enable_trtllm_sampler: True + sampler_type: "TRTLLMSampler" kv_cache_config: free_gpu_memory_fraction: 0.2 enable_partial_reuse: False cache_transceiver_config: - backend: "default" + backend: "DEFAULT" disable_overlap_scheduler: False urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index f88215581d..0d86204ecb 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -14,11 +14,14 @@ # limitations under the License. import os +import re import subprocess +import tempfile import pytest -from defs.conftest import skip_arm, skip_no_hopper -from defs.trt_test_alternative import check_call, popen +import yaml +from defs.conftest import llm_models_root, skip_arm, skip_no_hopper +from defs.trt_test_alternative import check_call, check_output, popen from tensorrt_llm.logger import logger @@ -44,6 +47,8 @@ def get_test_config(test_desc, example_dir, test_root): "gen_only": (2, f"{test_configs_root}/disagg_config_gen_only.yaml"), "gen_only_trt_backend": (2, f"{test_configs_root}/disagg_config_gen_only_trt_backend.yaml"), + "gen_only_bs1": + (4, f"{test_configs_root}/disagg_config_gen_only_bs1.yaml"), "4_ranks": (4, f"{test_configs_root}/disagg_config_ctxtp2_gentp1.yaml"), "4_ranks_trt_backend": (4, @@ -384,6 +389,29 @@ def test_disaggregated_benchmark_gen_only_trt_backend( cwd=llm_venv.get_working_directory()) +@pytest.mark.skip_less_device(4) +@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'], + indirect=True) +def test_disaggregated_genbs1(disaggregated_test_root, + disaggregated_example_root, llm_venv, + llama_model_root): + src_dst_dict = { + llama_model_root: + f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + + env = llm_venv._new_env.copy() + env['TRTLLM_DISAGG_BENCHMARK_GEN_ONLY'] = '1' + run_disaggregated_test(disaggregated_example_root, + "gen_only_bs1", + env=llm_venv._new_env, + cwd=llm_venv.get_working_directory()) + + @pytest.mark.skip_less_device(2) @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'], indirect=True) @@ -647,7 +675,6 @@ def test_disaggregated_ctxpp2_gentp2(disaggregated_test_root, llm_venv, def test_disaggregated_ctxtp2pp2_gentp2pp2(disaggregated_test_root, llm_venv, disaggregated_example_root, llama_model_root): - pytest.skip(f"8 GPU test times out currently, skipping") src_dst_dict = { llama_model_root: f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0", @@ -668,7 +695,6 @@ def test_disaggregated_ctxtp2pp2_gentp2pp2(disaggregated_test_root, llm_venv, def test_disaggregated_ctxpp4_genpp4(disaggregated_test_root, llm_venv, disaggregated_example_root, llama_model_root): - pytest.skip(f"8 GPU test times out currently, skipping") src_dst_dict = { llama_model_root: f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0", @@ -776,6 +802,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_ucx(disaggregated_test_root, cwd=llm_venv.get_working_directory()) +@skip_no_hopper @skip_arm @pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'], indirect=True) @@ -1051,3 +1078,227 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_two_mtp( "deepseek_v3_lite_fp8_tp1_two_mtp", env=llm_venv._new_env, cwd=llm_venv.get_working_directory()) + + +@pytest.fixture(scope="module") +def benchmark_root(): + llm_root = os.getenv("LLM_ROOT") + return os.path.join(llm_root, "tensorrt_llm", "serve", "scripts") + + +@pytest.fixture(scope="module") +def shared_gpt_path(): + DEFAULT_LLM_MODEL_ROOT = os.path.join("/scratch.trt_llm_data", "llm-models") + LLM_MODELS_ROOT = os.environ.get("LLM_MODELS_ROOT", DEFAULT_LLM_MODEL_ROOT) + return os.path.join(LLM_MODELS_ROOT, "datasets", + "ShareGPT_V3_unfiltered_cleaned_split.json") + + +@pytest.fixture(scope="function") +def benchmark_model_root(request): + models_root = llm_models_root() + if (request.param == "DeepSeek-V3-Lite-fp8"): + model_path = os.path.join(models_root, "DeepSeek-V3-Lite", "fp8") + elif (request.param == "DeepSeek-V3-Lite-bf16"): + model_path = os.path.join(models_root, "DeepSeek-V3-Lite", "bf16") + elif request.param == "llama-v3-8b-hf": + model_path = os.path.join(models_root, "llama-models-v3", "8B") + elif request.param == "llama-3.1-8b-instruct-hf-fp8": + model_path = os.path.join(models_root, "llama-3.1-model", + "Llama-3.1-8B-Instruct-FP8") + else: + raise ValueError(f"Failed to find the model: {request.param}") + return model_path + + +def run_disaggregated_benchmark(example_dir, + config_file, + benchmark_root, + benchmark_model_root, + shared_gpt_path, + env=None, + cwd=None): + """Run disaggregated test with given configuration.""" + run_env = env.copy() + run_env["UCX_TLS"] = "^ib" + num_rank = 2 + workers_cmd = [ + 'mpirun', '--allow-run-as-root', '--oversubscribe', '-n', + str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c', + config_file + ] + + server_start_timeout = 900 + server_cmd = [ + 'trtllm-serve', 'disaggregated', '--server_start_timeout', + str(server_start_timeout), '-c', config_file + ] + try: + with ( # Start workers + open('output_workers.log', 'w') as output_workers, + popen(workers_cmd, + stdout=output_workers, + stderr=subprocess.STDOUT, + env=run_env, + cwd=cwd) as workers_proc, + # Start server + open('output_disagg.log', 'w') as output_disagg, + popen(server_cmd, + stdout=output_disagg, + stderr=subprocess.STDOUT, + env=run_env, + cwd=cwd) as server_proc): + # Ensure the sever has started + client_dir = f"{example_dir}/clients" + client_cmd = [ + 'python3', f'{client_dir}/disagg_client.py', '-c', + f'{example_dir}/disagg_config.yaml', '-p', + f'{client_dir}/prompts.json', '--ignore-eos', + '--server-start-timeout', + str(server_start_timeout) + ] + # Warm up + check_call(client_cmd, + env=env, + poll_procs=[workers_proc, server_proc]) + # Start Benchmark + benchmark_script = os.path.join(benchmark_root, + "benchmark_serving.py") + benchmark_cmd = [ + 'python3', + benchmark_script, + '--model', + benchmark_model_root, + '--tokenizer', + benchmark_model_root, + '--dataset-name', + 'random', + '--dataset-path', + shared_gpt_path, + '--random-input-len', + '256', + '--random-output-len', + '64', + '--random-prefix-len', + '0', + '--num-prompts', + '320', + '--max-concurrency', + '32', + '--host', + 'localhost', + '--port', + '8000', + '--ignore-eos', + '--no-test-input', + '--percentile-metrics', + 'e2el,ttft', + ] + # warm up + check_call(benchmark_cmd, env=env) + output = check_output(benchmark_cmd, env=env) + e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)" + ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)" + e2el_match = re.search(e2el_pattern, output) + ttft_match = re.search(ttft_pattern, output) + if e2el_match and ttft_match: + median_e2el = float(e2el_match.group(1)) + median_ttft = float(ttft_match.group(1)) + return median_e2el, median_ttft + else: + raise ValueError("No benchmark result found") + + except Exception: + # Print outputs on error + logger.error("-------- Workers output --------") + with open('output_workers.log', 'r') as f: + logger.error(f.read()) + + logger.error("-------- Disagg server output --------") + with open('output_disagg.log', 'r') as f: + logger.error(f.read()) + raise + finally: + server_proc.terminate() + workers_proc.terminate() + server_proc.wait() + workers_proc.wait() + + +def get_config_for_benchmark(model_root, backend): + serve_config = { + "model": model_root, + "hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "max_batch_size": 2, + "max_num_tokens": 384, + "max_seq_len": 384, + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "disable_overlap_scheduler": True, + "cache_transceiver_config": { + "backend": backend, + "max_tokens_in_buffer": 512, + }, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "max_batch_size": 2, + "max_num_tokens": 384, + "max_seq_len": 384, + "cache_transceiver_config": { + "backend": backend, + "max_tokens_in_buffer": 512, + }, + "urls": ["localhost:8002"] + } + } + return serve_config + + +@pytest.mark.parametrize("benchmark_model_root", [ + 'DeepSeek-V3-Lite-fp8', 'DeepSeek-V3-Lite-bf16', 'llama-v3-8b-hf', + 'llama-3.1-8b-instruct-hf-fp8' +], + indirect=True) +def test_disaggregated_benchmark_on_diff_backends( + disaggregated_test_root, disaggregated_example_root, llm_venv, + benchmark_model_root, benchmark_root, shared_gpt_path): + nixl_config = get_config_for_benchmark(benchmark_model_root, "NIXL") + ucx_config = get_config_for_benchmark(benchmark_model_root, "UCX") + temp_dir = tempfile.TemporaryDirectory() + nixl_config_path = os.path.join(temp_dir.name, "nixl_config.yaml") + ucx_config_path = os.path.join(temp_dir.name, "ucx_config.yaml") + with open(nixl_config_path, 'w', encoding='utf-8') as f: + yaml.dump(nixl_config, f) + with open(ucx_config_path, 'w', encoding='utf-8') as f: + yaml.dump(ucx_config, f) + + env = llm_venv._new_env.copy() + nixl_e2el, nixl_ttft = run_disaggregated_benchmark( + disaggregated_example_root, + nixl_config_path, + benchmark_root, + benchmark_model_root, + shared_gpt_path, + env=env, + cwd=llm_venv.get_working_directory()) + ucx_e2el, ucx_ttft = run_disaggregated_benchmark( + disaggregated_example_root, + ucx_config_path, + benchmark_root, + benchmark_model_root, + shared_gpt_path, + env=env, + cwd=llm_venv.get_working_directory()) + print(f"Nixl E2EL: {nixl_e2el} ms, UCX E2EL: {ucx_e2el} ms") + print(f"Nixl TTFT: {nixl_ttft} ms, UCX TTFT: {ucx_ttft} ms") + + assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el + assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft diff --git a/tests/integration/defs/disaggregated/test_disaggregated_etcd.py b/tests/integration/defs/disaggregated/test_disaggregated_etcd.py index 656b9a675d..a495f35faf 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated_etcd.py +++ b/tests/integration/defs/disaggregated/test_disaggregated_etcd.py @@ -244,7 +244,7 @@ def create_config_files(config): context_config_content = """pytorch_backend_config: disable_overlap_scheduler: True cache_transceiver_config: - backend: "default" + backend: "DEFAULT" max_tokens_in_buffer: 2048""" with open(CONTEXT_CONFIG_FILE, 'w') as file: @@ -252,7 +252,7 @@ cache_transceiver_config: # Create generation config file generation_config_content = """cache_transceiver_config: - backend: "default" + backend: "DEFAULT" max_tokens_in_buffer: 2048""" with open(GENERATION_CONFIG_FILE, 'w') as file: diff --git a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py index 55971c3ad0..93611de040 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py +++ b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py @@ -131,7 +131,7 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt, kv_cache_configs = [KvCacheConfig(max_tokens=2048 * 8) for _ in range(2)] cache_transceiver_configs = [ - CacheTransceiverConfig(backend="default") for _ in range(2) + CacheTransceiverConfig(backend="DEFAULT") for _ in range(2) ] model_names = [model_path(model) for _ in range(2)] ranks = [0, 1] @@ -274,7 +274,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, for _ in range(2) ] cache_transceiver_configs = [ - CacheTransceiverConfig(backend="default") for _ in range(2) + CacheTransceiverConfig(backend="DEFAULT") for _ in range(2) ] model_names = [model_path(model) for _ in range(2)] ranks = [0, 1] @@ -349,13 +349,15 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, @pytest.mark.parametrize("model", ["Llama-3.1-8B-Instruct"]) @pytest.mark.parametrize("spec_dec_model_path", ["EAGLE3-LLaMA3.1-Instruct-8B"]) @pytest.mark.parametrize("generation_overlap", [False]) +@pytest.mark.parametrize("eagle3_one_model", [True, False]) def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, - generation_overlap): + generation_overlap, + eagle3_one_model): # Test whether the batch slots are properly released when using speculative decoding # with disaggregated serving. spec_dec_config = EagleDecodingConfig( speculative_model_dir=model_path(spec_dec_model_path), - eagle3_one_model=False, + eagle3_one_model=eagle3_one_model, max_draft_len=3) worker_pytorch_configs = [] @@ -377,7 +379,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, for _ in range(2) ] cache_transceiver_configs = [ - CacheTransceiverConfig(backend="default") for _ in range(2) + CacheTransceiverConfig(backend="DEFAULT") for _ in range(2) ] model_names = [model_path(model) for _ in range(2)] ranks = [0, 1] diff --git a/tests/integration/defs/perf/README_release_test.md b/tests/integration/defs/perf/README_release_test.md index 7bff0ed37d..2fe42147c7 100644 --- a/tests/integration/defs/perf/README_release_test.md +++ b/tests/integration/defs/perf/README_release_test.md @@ -111,15 +111,40 @@ if self._config.backend == "pytorch": ### 3.1 Full Test Cycles -1. **trt_llm_release_perf_test.yml** - Release performance test -2. **trt_llm_perf_cluster_test.yml** - Cluster performance test +1. **llm_perf_full.yml** - Release performance test + - [test_lists/qa/llm_perf_full.yml](../../test_lists/qa/llm_perf_full.yml) +2. **llm_perf_cluster.yml** - Cluster performance test(for Blackwell) + - [test_lists/qa/llm_perf_cluster.yml](../../test_lists/qa/llm_perf_cluster.yml) +3. **llm_perf_nim.yml** - NIM performance test + - [test_lists/qa/llm_perf_nim.yml](../../test_lists/qa/llm_perf_nim.yml) ### 3.2 Sanity Test Cycles -- **trt_llm_release_perf_sanity.yml** - Release performance sanity test +- **llm_perf_sanity.yml** - Release performance sanity test + - [test_lists/qa/llm_perf_sanity.yml](../../test_lists/qa/llm_perf_sanity.yml) ## 4. Test Configuration Description +### 4.1 PyTorch Model Configuration + +The default PyTorch configuration is defined in [pytorch_model_config.py](pytorch_model_config.py) and can be overridden for specific test patterns. For example: + +```python +{ + 'patterns': [ + 'qwen3_235b_a22b_fp4-bench-pytorch-float4-maxbs:512-maxnt:2048-input_output_len:1000,2000-con:8-ep:8-gpus:8', + ], + 'config': { + 'enable_attention_dp': False, + 'moe_config': { + 'backend': 'TRTLLM' + } + } +} +``` + +This configuration allows you to customize PyTorch-specific settings for different model patterns while maintaining the base configuration as a fallback. + ### 4.1 Test Case Configuration - Test cases are defined in YAML configuration files - Support for different models, precisions, batch sizes, etc. diff --git a/tests/integration/defs/perf/test_perf.py b/tests/integration/defs/perf/test_perf.py index f443ca1035..b854a54c2b 100644 --- a/tests/integration/defs/perf/test_perf.py +++ b/tests/integration/defs/perf/test_perf.py @@ -127,6 +127,7 @@ MODEL_PATH_DICT = { "phi_4_multimodal_instruct_audio": "multimodals/Phi-4-multimodal-instruct", "bielik_11b_v2.2_instruct": "Bielik-11B-v2.2-Instruct", "bielik_11b_v2.2_instruct_fp8": "Bielik-11B-v2.2-Instruct-FP8", + "mistral_small_v3.1_24b": "Mistral-Small-3.1-24B-Instruct-2503", } # Model PATH of HuggingFace HF_MODEL_PATH = { diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index c9d13f31fc..a46bfd4c91 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -619,7 +619,7 @@ def test_trtllm_bench_invalid_token_pytorch(llm_root, llm_venv, model_name, f"throughput " \ f"--dataset {str(dataset_path)} --backend pytorch " \ f"--extra_llm_api_options {extra_options_path} " \ - f"> {output_path}" + f"> {output_path} 2>&1" # Check clean shutdown (no hang) with pytest.raises(subprocess.CalledProcessError) as exc_info: check_call(benchmark_cmd, shell=True, env=llm_venv._new_env) @@ -629,7 +629,7 @@ def test_trtllm_bench_invalid_token_pytorch(llm_root, llm_venv, model_name, stdout = f.read() # Check that error is reported correctly - assert "Error during benchmarking: Requests failed: Token ID out of range (1 requests)" in stdout + assert "Requests failed: Token ID out of range (1 requests)" in stdout def trtllm_bench_prolog( @@ -1497,6 +1497,13 @@ def test_openai_chat_with_logit_bias(llm_root, llm_venv, sampler: str): ]) +def test_openai_prometheus(llm_root, llm_venv): + test_root = unittest_path() / "llmapi" / "apps" + llm_venv.run_cmd( + ["-m", "pytest", + str(test_root / "_test_openai_prometheus.py")]) + + def test_openai_lora(llm_root, llm_venv): test_root = unittest_path() / "llmapi" / "apps" llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")]) @@ -1583,6 +1590,15 @@ def test_build_time_benchmark_sanity(llm_root, llm_venv): ]) +@pytest.mark.skip_less_device_memory(80000) +def test_trtllm_multimodal_benchmark_serving(llm_root, llm_venv): + test_root = unittest_path() / "llmapi" / "apps" + llm_venv.run_cmd([ + "-m", "pytest", + str(test_root / "_test_trtllm_serve_multimodal_benchmark.py") + ]) + + ### PyTorch examples @@ -2165,8 +2181,8 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, "Describe the scene in the image briefly.", ], "media": [ - [], - [str(test_data_root / "inpaint.png")], + "", + str(test_data_root / "inpaint.png"), ], } } @@ -2174,7 +2190,7 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, expected_keywords = { "NVILA-8B-FP16": { "image": [ - ["stormy", "ocean", "waves", "clouds", "gray", "sky"], + ["stormy", "ocean", "waves", "cloudy", "sunlight", "sky"], ["rock", "formation", "sunny", "sky", "clouds"], ["road", "busy", "car", "black", "blue"], ], @@ -2189,7 +2205,7 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, "llava-v1.6-mistral-7b": { "image": [ ["ocean", "sky", "large", "waves", "shore", "blue"], - ['mountain', 'flat', 'dome', 'formation', 'sky'], + ['mountain', 'flat', 'clouds', 'road', 'sky'], ["highway", "vehicles", "traffic", "bus", "suburban"], ], }, @@ -2206,9 +2222,12 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, }, "qwen2.5-vl-7b-instruct": { "image": [ - ["dramatic", "moody", "ocean", "stormy", "sky", "clouds"], + ["dramatic", "moody", "ocean", "stormy", "sky", "waves"], ["large", "dome", "yosemite", "landmark", "rock", "road"], - ["highway", "traffic", "vehicles", "bus", "police", "traffic"], + [ + "highway", "traffic", "vehicles", "lanes", "congestion", + "road" + ], ], "video": [ ["woman", "neon", "night", "jacket", "wet"], @@ -2226,7 +2245,7 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, ], "mixture_text_image": [["invention", "person", "scientists", "Lick", "engineers"], - ["landscape", "dome", "yosemite", "altitude", "scattered"]] + ["landscape", "trees", "road", "natural", "rock"]] }, "gemma-3-27b-it": { "image": [ @@ -2261,6 +2280,8 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, cmd.append("--image_format=pil") cmd.append("--attention_backend=FLASHINFER") cmd.append("--disable_kv_cache_reuse") + cmd.append("--kv_cache_fraction=0.5") + cmd.append("--max_seq_len=1024") output = llm_venv.run_cmd(cmd, caller=check_output) diff --git a/tests/integration/defs/triton_server/conftest.py b/tests/integration/defs/triton_server/conftest.py index bfdfb4bb4c..2afebbee14 100644 --- a/tests/integration/defs/triton_server/conftest.py +++ b/tests/integration/defs/triton_server/conftest.py @@ -13,6 +13,14 @@ from .trt_test_alternative import (SessionDataWriter, check_call, check_output, print_info) +def find_repo_root(): + """Find the repository root by going up 4 directories from the current file location.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + for _ in range(4): + current_dir = os.path.dirname(current_dir) + return current_dir + + def llm_models_root() -> str: '''return LLM_MODELS_ROOT path if it is set in env, assert when it's set but not a valid path ''' @@ -68,7 +76,11 @@ def output_dir(request): @pytest.fixture(scope="session") def llm_backend_root(): - return os.path.join(os.environ["LLM_ROOT"], "triton_backend") + llm_root = os.environ.get("LLM_ROOT", find_repo_root()) + backend_root = os.path.join(llm_root, "triton_backend") + assert os.path.isabs(backend_root), "LLM backend path must be absolute" + assert os.path.exists(backend_root), f"{backend_root} does not exist" + return backend_root @pytest.fixture(scope="session") diff --git a/tests/integration/defs/triton_server/test_triton_llm.py b/tests/integration/defs/triton_server/test_triton_llm.py index fdf36756d3..d6f4be2b05 100644 --- a/tests/integration/defs/triton_server/test_triton_llm.py +++ b/tests/integration/defs/triton_server/test_triton_llm.py @@ -5,13 +5,14 @@ import pytest import torch import yaml -sys.path.append(os.path.join(os.environ["LLM_ROOT"], "triton_backend")) - from .build_engines import * from .common import * -from .conftest import venv_check_call, venv_check_output +from .conftest import find_repo_root, venv_check_call, venv_check_output from .trt_test_alternative import call, check_call, print_info +LLM_ROOT = os.environ.get("LLM_ROOT", find_repo_root()) +sys.path.append(os.path.join(LLM_ROOT, "triton_backend")) + @pytest.fixture(autouse=True) def stop_triton_server(): @@ -91,7 +92,7 @@ def test_llama_v2_7b_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENGINE_PATH = prepare_llama_v2_7b_engine("ifb", tensorrt_llm_llama_example_root, @@ -216,7 +217,7 @@ def test_mistral_v1_7b_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = prepare_mistral_v1_7b_engine("ifb", tensorrt_llm_llama_example_root, @@ -333,7 +334,7 @@ def test_mistral_v1_multi_models( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = prepare_mistral_v1_7b_engine("ifb", tensorrt_llm_llama_example_root, @@ -402,8 +403,7 @@ def test_mistral_v1_7b_python_backend( tensorrt_llm_llama_example_root, llm_backend_venv, ): - llm_backend_repo_root = os.path.join(os.environ["LLM_ROOT"], - "triton_backend") + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = prepare_mistral_v1_7b_engine("python_backend", tensorrt_llm_llama_example_root, @@ -520,7 +520,7 @@ def test_llama_v2_70b_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = prepare_llama_v2_70b_engine("ifb", tensorrt_llm_llama_example_root, @@ -640,7 +640,7 @@ def test_llama_v2_70b_ifb_lad( if BATCHING_STRATEGY == "V1" and BATCH_SCHEDULER_POLICY == "max_utilization": pytest.skip("Skipping. V1 doesn't support max_utilization.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = prepare_llama_v2_70b_engine("ifb", @@ -765,7 +765,7 @@ def test_medusa_vicuna_7b_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = prepare_medusa_vicuna_7b_engine( tensorrt_llm_medusa_example_root, vicuna_7b_model_root, @@ -890,7 +890,7 @@ def test_eagle_vicuna_7b_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = prepare_eagle_vicuna_7b_engine( tensorrt_llm_eagle_example_root, vicuna_7b_model_root, @@ -967,7 +967,7 @@ def test_gpt_350m_python_backend( gpt_tokenizer_model_root, llm_backend_venv, ): - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENGINE_PATH = prepare_gpt_350m_engine( "python_backend", @@ -1096,7 +1096,7 @@ def test_gpt_350m_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENGINE_PATH = prepare_gpt_350m_engine( "ifb", @@ -1232,7 +1232,7 @@ def test_t5_small_enc_dec_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENCODER_ENGINE_DIR, ENGINE_DIR = prepare_t5_small_engine( tensorrt_llm_enc_dec_example_root, t5_small_model_root) @@ -1353,7 +1353,7 @@ def test_whisper_large_v3_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENCODER_ENGINE_DIR, ENGINE_DIR = prepare_whisper_large_engine( tensorrt_llm_whisper_example_root, whisper_large_model_root) @@ -1489,7 +1489,7 @@ def test_gpt_gather_logits_ifb( if BATCHING_STRATEGY == "V1" and BATCH_SCHEDULER_POLICY == "max_utilization": pytest.skip("Skipping. V1 doesn't support max_utilization.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENGINE_PATH = prepare_gpt_gather_logits_engine( "ifb", @@ -1616,7 +1616,7 @@ def test_gpt_350m_speculative_decoding( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine CONTROL_ENGINE_DIR = prepare_gpt_350m_engine( "medium_control_ifb", @@ -1806,7 +1806,7 @@ def test_gpt_350m_speculative_decoding_return_logits( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine CONTROL_ENGINE_DIR = prepare_gpt_350m_engine( "medium_control_ifb", @@ -1999,7 +1999,7 @@ def test_gpt_speculative_decoding_bls( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine CONTROL_ENGINE_DIR = prepare_gpt_350m_engine( "medium_control_ifb", @@ -2159,7 +2159,7 @@ def test_llama_v3_speculative_decoding_bls( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine DRAFT_ENGINE_DIR = prepare_llama_v3_8b_engine( tensorrt_llm_example_root, @@ -2324,7 +2324,7 @@ def test_gpt_175b_dummyWeights_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = prepare_gpt_175b_engine("ifb", tensorrt_llm_gpt_example_root, tensorrt_llm_example_root) @@ -2440,7 +2440,7 @@ def test_llava( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH, MULTIMODAL_ENGINE_DIR = prepare_llava_engine( tensorrt_llm_multimodal_example_root, tensorrt_llm_llama_example_root, @@ -2576,7 +2576,7 @@ def test_llava_onevision( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH, MULTIMODAL_ENGINE_DIR = prepare_llava_onevision_engine( tensorrt_llm_multimodal_example_root, tensorrt_llm_qwen_example_root, @@ -2735,7 +2735,7 @@ def test_mllama( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH, MULTIMODAL_ENGINE_DIR = prepare_mllama_engine( tensorrt_llm_multimodal_example_root, tensorrt_llm_mllama_example_root, @@ -2910,7 +2910,7 @@ def test_gpt_next_ptuning_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENGINE_PATH, output_model_dir = prepare_gpt_next_ptuning_engine( "ifb", tensorrt_llm_gpt_example_root, gpt_next_ptuning_model_root) @@ -3088,7 +3088,7 @@ def test_gpt_2b_lora_ifb( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine weight_streaming = float(GPU_WEIGHTS_PERCENT) < 1.0 ENGINE_PATH = prepare_gpt_2b_lora_engine("ifb", @@ -3239,7 +3239,7 @@ def test_tiny_llama_1b_guided_decoding( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENGINE_PATH, XGRAMMAR_TOKENIZER_INFO_PATH = prepare_tiny_llama_1b_engine( @@ -3388,7 +3388,7 @@ def test_gpt_disaggregated_serving_bls( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENGINE_PATH = prepare_gpt_350m_engine( "ifb", @@ -3556,7 +3556,7 @@ def test_benchmark_core_model( llm_backend_venv, ): - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build Engine ENGINE_PATH = model_setup["prepare_engine_fn"]( "ifb", model_setup["example_root"], model_setup["tokenizer_path"]) @@ -3632,7 +3632,7 @@ def test_llmapi_backend(E2E_MODEL_NAME, DECOUPLED_MODE, TRITON_MAX_BATCH_SIZE, TENSOR_PARALLEL_SIZE, llm_backend_inflight_batcher_llm_root, llm_backend_venv, llm_backend_dataset_root): - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") if torch.cuda.device_count() < int(TENSOR_PARALLEL_SIZE): pytest.skip("Skipping. Not enough GPUs.") @@ -3726,6 +3726,12 @@ def test_llmapi_backend(E2E_MODEL_NAME, DECOUPLED_MODE, TRITON_MAX_BATCH_SIZE, output = venv_check_output(llm_backend_venv, run_cmd) assert 'Request is cancelled' in output + # Test request cancellation for non-existing request and completed request + run_cmd = [ + f"{llm_backend_repo_root}/tools/tests/test_llmapi_cancel.py" + ] + output = venv_check_output(llm_backend_venv, run_cmd) + @pytest.mark.parametrize("E2E_MODEL_NAME", ["ensemble", "tensorrt_llm_bls"]) @pytest.mark.parametrize("ACCUMULATE_TOKEN", ["False"]) @@ -3789,7 +3795,7 @@ def test_tiny_llama_ifb_token_counts( if E2E_MODEL_NAME == "ensemble" and ACCUMULATE_TOKEN == "True": pytest.skip("Skipping.") - llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + llm_backend_repo_root = os.path.join(LLM_ROOT, "triton_backend") # Build engine ENGINE_PATH, _ = prepare_tiny_llama_1b_engine( type="ifb", diff --git a/tests/integration/test_lists/qa/llm_function_full.txt b/tests/integration/test_lists/qa/llm_function_full.txt index ddd12c4168..1859762fc1 100644 --- a/tests/integration/test_lists/qa/llm_function_full.txt +++ b/tests/integration/test_lists/qa/llm_function_full.txt @@ -448,6 +448,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance] accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True] @@ -477,7 +481,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_pref accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4_chunked_prefill[tp4ep4-cuda_graph=True] accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_fp8_tp2 accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_nvfp4_tp2 -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] @@ -507,11 +511,36 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] +accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency] +accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRITON] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRITON] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-TRITON] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-TRITON] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2] @@ -544,6 +573,8 @@ accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype_long_rope accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-] test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-] @@ -557,6 +588,8 @@ test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1 test_e2e.py::test_openai_multi_chat_example test_e2e.py::test_openai_consistent_chat test_e2e.py::test_trtllm_benchmark_serving +test_e2e.py::test_trtllm_multimodal_benchmark_serving + llmapi/test_llm_examples.py::test_llmapi_server_example # Pivot to Pytorch test cases. test_e2e.py::test_ptp_quickstart diff --git a/tests/integration/test_lists/qa/llm_function_rtx6kd.txt b/tests/integration/test_lists/qa/llm_function_rtx6kd.txt index fbabac6b84..b3d14c393b 100644 --- a/tests/integration/test_lists/qa/llm_function_rtx6kd.txt +++ b/tests/integration/test_lists/qa/llm_function_rtx6kd.txt @@ -1,16 +1,16 @@ accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2 accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights accuracy/test_cli_flow.py::TestMixtral8x7B::test_nvfp4_prequantized -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] diff --git a/tests/integration/test_lists/qa/llm_function_sanity.txt b/tests/integration/test_lists/qa/llm_function_sanity.txt index 16606a0795..8dc118d991 100644 --- a/tests/integration/test_lists/qa/llm_function_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_sanity.txt @@ -25,6 +25,7 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=FLASHINFER] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=TRTLLM] accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_auto_dtype @@ -36,7 +37,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] @@ -100,8 +101,18 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutl accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=True] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRITON] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype +accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency] +accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency] +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0] disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8] @@ -157,3 +168,4 @@ test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1 test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1] test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False] test_e2e.py::test_trtllm_benchmark_serving +test_e2e.py::test_trtllm_multimodal_benchmark_serving diff --git a/tests/integration/test_lists/qa/llm_perf_cluster.yml b/tests/integration/test_lists/qa/llm_perf_cluster.yml index 47877a3fcc..878d9129e6 100644 --- a/tests/integration/test_lists/qa/llm_perf_cluster.yml +++ b/tests/integration/test_lists/qa/llm_perf_cluster.yml @@ -1,5 +1,5 @@ version: 0.0.1 -trt_llm_release_perf_cluster_test: +llm_perf_cluster: - condition: ranges: system_gpu_count: @@ -42,6 +42,11 @@ trt_llm_release_perf_cluster_test: - perf/test_perf.py::test_perf[ministral_8b_fp8-bench-pytorch-float8-maxbs:1-input_output_len:500,2000-reqs:8-con:1] - perf/test_perf.py::test_perf[ministral_8b_fp8-bench-pytorch-float8-maxnt:5000-input_output_len:5000,500-reqs:500-con:250] - perf/test_perf.py::test_perf[ministral_8b_fp8-bench-pytorch-float8-input_output_len:500,2000-reqs:500-con:250] + #Mistral-Small-3.1-24B-Instruct-2503 + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-input_output_len:1000,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-input_output_len:1000,2000-reqs:500-con:200] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] TIMEOUT(120) + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:4096-maxnt:20000-input_output_len:20000,2000-reqs:500-con:200] TIMEOUT(120) - condition: @@ -57,6 +62,11 @@ trt_llm_release_perf_cluster_test: - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-streaming-float8-maxbs:256-input_output_len:512,32-gpus:2] - perf/test_perf.py::test_perf[llama_v2_13b-bench-float16-input_output_len:128,128-loras:8-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-quant:fp8-gpus:2] + #Mistral-Small-3.1-24B-Instruct-2503 + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-input_output_len:1000,2000-reqs:8-con:1-gpus:2] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-input_output_len:1000,2000-reqs:500-con:200-gpus:2] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1-gpus:2] TIMEOUT(120) + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:4096-maxnt:20000-input_output_len:20000,2000-reqs:500-con:200-gpus:2] TIMEOUT(120) # Tests for systems with 4+ GPUs - condition: diff --git a/tests/integration/test_lists/qa/llm_perf_full.yml b/tests/integration/test_lists/qa/llm_perf_full.yml index c4778586b5..dfbf59351f 100644 --- a/tests/integration/test_lists/qa/llm_perf_full.yml +++ b/tests/integration/test_lists/qa/llm_perf_full.yml @@ -1,5 +1,5 @@ version: 0.0.1 -trt_llm_release_perf_test: +llm_perf_full: # one gpu test - condition: ranges: @@ -14,29 +14,16 @@ trt_llm_release_perf_test: - '*l20*' - '*h20*' tests: - # E2E BERT - - perf/test_perf.py::test_perf[bert_large-cpp-plugin-float16-bs:32+64-input_len:128+512] - - perf/test_perf.py::test_perf[roberta_base-cpp-plugin-float16-bs:32+64-input_len:128+512] - - # E2E gptManagerBenchmark IFB - # E2E ENC-DEC - - perf/test_perf.py::test_perf[t5_large-cppmanager-exe-plugin_ifb-float16-mp-input_output_len:128,20] - # E2E trtllm-bench #llama_v3.1_8b_instruct - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-streaming-bfloat16-input_output_len:128,128] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-streaming-bfloat16-input_output_len:512,32] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:128,128] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:512,32] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:128,128] - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:512,32] - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:128,128] - - perf/test_perf.py::test_perf[qwen2_7b_instruct-bench-float16-input_output_len:128,128] + - perf/test_perf.py::test_perf[qwen2_7b_instruct-bench-pytorch-float16-input_output_len:128,128] - perf/test_perf.py::test_perf[starcoder2_3b-bench-pytorch-float16-input_output_len:512,200] - - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:128,128] + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-pytorch-float16-input_output_len:128,128] # Ministral-8B - perf/test_perf.py::test_perf[ministral_8b-bench-pytorch-bfloat16-maxbs:1-maxnt:5000-input_output_len:5000,500-reqs:8-con:1] @@ -47,26 +34,6 @@ trt_llm_release_perf_test: # Ministral-8B LoRA tests (using dummy Mistral LoRA checkpoint) - perf/test_perf.py::test_perf[ministral_8b-bench-pytorch-bfloat16-maxbs:2-maxnt:1024-input_output_len:128,128-loras:1-reqs:8-con:2] - # E2E ENC-DEC - - perf/test_perf.py::test_perf[bart_large_cnn-cppmanager-exe-plugin_ifb-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[mbart_large_50_many_to_one_mmt-cppmanager-exe-plugin_ifb-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[t5-bench-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[flan_t5_base-bench-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[flan_t5_large-bench-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[whisper_large_v3-bench-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[mamba_370m-bench-float16-input_output_len:128,128] - - perf/test_perf.py::test_perf[mamba_370m-bench-float16-input_output_len:512,32] - - perf/test_perf.py::test_perf[mamba_2.8b-bench-float16-input_output_len:128,128] - - perf/test_perf.py::test_perf[mamba_2.8b-bench-float16-input_output_len:512,32] - - # Phi-4-mini-instruct - # cpp - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-maxnt:5000-input_output_len:5000,500-con:250] - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-con:250] - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-quant:fp8-con:250] - # reduced 'reqs' to fit timeout limit - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-reqs:8-con:1] - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-quant:fp8-reqs:8-con:1] # Phi-4-multimodal-instruct - perf/test_perf.py::test_perf[phi_4_multimodal_instruct-bench-pytorch-bfloat16-input_output_len:500,2000-con:250] - perf/test_perf.py::test_perf[phi_4_multimodal_instruct-bench-pytorch-bfloat16-input_output_len:1000,1000-con:250] @@ -82,6 +49,11 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-pytorch-bfloat16-input_output_len:2000,500] - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-pytorch-bfloat16-input_output_len:128,128] - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-pytorch-bfloat16-input_output_len:512,32] + #Mistral-Small-3.1-24B-Instruct-2503 + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-input_output_len:1000,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-input_output_len:1000,2000-reqs:500-con:200] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] TIMEOUT(120) + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:4096-maxnt:20000-input_output_len:20000,2000-reqs:500-con:200] TIMEOUT(120) # Test list validation - test_list_validation.py::test_list_validation @@ -98,48 +70,7 @@ trt_llm_release_perf_test: tests: - perf/test_perf.py::test_perf[phi_4_multimodal_instruct_image-bench-pytorch-bfloat16-input_output_len:1000,1000-loras:1-con:250] - perf/test_perf.py::test_perf[phi_4_multimodal_instruct_audio-bench-pytorch-bfloat16-input_output_len:1000,1000-loras:1-con:250] - - perf/test_perf.py::test_perf[llama_v3_8b_instruct-cppmanager-exe-static_batching-plugin_ifb-float16-bs:8+64-input_output_len:128,128+512,32] - - perf/test_perf.py::test_perf[llama_v3_8b_instruct-cppmanager-exe-plugin_ifb-float16-mp-input_output_len:128,128+512,32] - - perf/test_perf.py::test_perf[llama_v3_8b_instruct-cppmanager-exe-plugin_ifb-bfloat16-gwp:0.0-input_output_len:128,128+512,32] - - perf/test_perf.py::test_perf[llama_v3_8b_instruct-cppmanager-exe-plugin_ifb-bfloat16-gwp:0.5-input_output_len:128,128+512,32] - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-maxnt:5000-input_output_len:5000,500-reqs:10-con:1] - # Llama-3.1-Nemotron-Nano-8B-v1 - # cpp backend - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:500,2000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:1000,1000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:20000-input_output_len:20000,2000-quant:fp8-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-con:250] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:500,2000-con:250] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:1000,1000-con:250] - # pyt backend - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:500,2000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:1000,1000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:500,2000-reqs:500-con:250] - # FP8 prequantized pyt backend - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:500,2000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:1000,1000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:500-con:250] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:500,2000-reqs:500-con:250] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:1000,1000-reqs:500-con:250] - #long time llama_nemotron cases - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-reqs:8-con:1] # timeout for l20, l40s, a100 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:8-con:1] #timeout for l20, l40s, failed for a100 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:1000,1000-quant:fp8-reqs:8-con:1] # timeout for l20, l40s, failed on a100 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-quant:fp8-con:250] # failed for a100 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-con:250] # failed on A100 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:1000,1000-quant:fp8-con:250] # failed on A100 15 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:20000-input_output_len:20000,2000-con:250] # timeout for l20, l40s, a100 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:20000-input_output_len:20000,2000-quant:fp8-con:250] # timeout for l20, l40s, failed on A100 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:500-con:250] # failed for l20, need to extend context token to 5000 for l40s and a100, timeout for h20 - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:1000,1000-reqs:500-con:250] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-maxnt:20000-input_output_len:20000,2000-reqs:500-con:250] #need to extend context token to 20000 for l40s, timeout for h20, a100 # deepseek_v3_lite_fp8 - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-input_output_len:128,128] - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-input_output_len:2000,500] @@ -158,12 +89,6 @@ trt_llm_release_perf_test: - '*l20*' - '*h20*' tests: - #llama_v3.1_8b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_8b-cppmanager-exe-plugin_ifb-float16-maxbs:256-input_output_len:128,128-beams:4-quant:fp8] - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:128,128-quant:w4a16_awq] - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:128,128-quant:w4a8_awq] - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-maxbs:256-input_output_len:128,128-quant:fp8] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:512,32] @@ -171,14 +96,7 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:2000,500] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:500,2000] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-loras:1-reqs:100-con:2-gpus:1] - #mistral_7b_v0.1 - #trt backend - - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-maxbs:256-input_output_len:1000,1000-quant:fp8] - - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-maxbs:256-input_output_len:500,2000-quant:fp8] - #phi_3_mini_4k_instruct - #trt backend - - perf/test_perf.py::test_perf[phi_3_mini_4k_instruct-bench-float16-maxbs:128-input_output_len:1000,1000-quant:fp8] - - perf/test_perf.py::test_perf[phi_3_mini_4k_instruct-bench-float16-maxbs:64-input_output_len:500,2000-quant:fp8] + - perf/test_perf.py::test_perf[bielik_11b_v2.2_instruct_fp8-bench-pytorch-float8-input_output_len:1000,1000-con:250] - perf/test_perf.py::test_perf[bielik_11b_v2.2_instruct_fp8-bench-pytorch-float8-input_output_len:2000,2000-con:250] @@ -188,18 +106,6 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[ministral_8b_fp8-bench-pytorch-float8-maxnt:5000-input_output_len:5000,500-reqs:500-con:250] - perf/test_perf.py::test_perf[ministral_8b_fp8-bench-pytorch-float8-input_output_len:500,2000-reqs:500-con:250] -- condition: - terms: - supports_fp8: true - wildcards: - gpu: - - '*h100*' - - '*h200*' - - '*h20*' - tests: - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:10-con:1] - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:10-con:250] - - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-quant:fp8-reqs:10-con:250] # 2 gpus test - condition: @@ -216,21 +122,13 @@ trt_llm_release_perf_test: - '*h20*' tests: #llama_v3.1_8b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_8b-cppmanager-exe-plugin_ifb-bfloat16-mp-maxbs:256-input_output_len:128,128-pp:2] - #pytorch backend + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-mp-maxbs:256-input_output_len:128,128-pp:2] - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:128,128-gpus:2] - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-maxbs:256-input_output_len:128,128-gpus:2] #mixtral_8x7b_v0.1 - #trt backend - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-loras:8-gpus:2] #pytorch backend - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-pytorch-float16-input_output_len:128,128-loras:8-gpus:2] #llama_v3.2_1b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-maxnt:5000-input_output_len:5000,500-reqs:10-con:1-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-maxnt:5000-input_output_len:5000,500-reqs:10-con:250-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:128,128-gpus:2] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:2000,500-reqs:10-con:1-gpus:2] - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:500,2000-gpus:2] @@ -238,9 +136,11 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:512,32-gpus:2] - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:512,200-gpus:2] - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:500,2000-reqs:10-con:1-gpus:2] - #t5 - - perf/test_perf.py::test_perf[t5-bench-float16-input_output_len:128,20-gpus:2] - - perf/test_perf.py::test_perf[flan_t5_large-bench-float16-input_output_len:128,20-gpus:2] + #Mistral-Small-3.1-24B-Instruct-2503 + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-input_output_len:1000,2000-reqs:8-con:1-gpus:2] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-input_output_len:1000,2000-reqs:500-con:200-gpus:2] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1-gpus:2] TIMEOUT(120) + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:4096-maxnt:20000-input_output_len:20000,2000-reqs:500-con:200-gpus:2] TIMEOUT(120) - condition: ranges: @@ -255,14 +155,6 @@ trt_llm_release_perf_test: - '*a100*' - '*h20*' tests: - #llama_v3.1_70b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:1024,1024-tp:2-gpus:2] - - perf/test_perf.py::test_perf[llama_70b_sq_per_tensor-cppmanager-exe-plugin_ifb-float16-input_output_len:128,128+512,32-gpus:2] - #mixtral_8x7b_v0.1 - #trt backend - - perf/test_perf.py::test_perf[mixtral_8x7b-cppmanager-exe-plugin_ifb-float16-mp-input_output_len:128,128+512,32-gpus:2] - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-gpus:2] #pytorch backend - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-pytorch-float16-input_output_len:128,128-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-pytorch-streaming-float16-input_output_len:128,128-gpus:2] @@ -282,23 +174,13 @@ trt_llm_release_perf_test: - '*l20*' - '*h20*' tests: - #llama_v3.2_1b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:512,32-quant:fp8-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:512,200-quant:fp8-gpus:2] #mixtral_8x7b_v0.1_fp8 pytorch backend - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-gpus:2] - #mistral_7b_v0.1 - #trt backend - - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:1000,1000-quant:fp8-tp:2] - - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:500,2000-quant:fp8-tp:2] #phi_3_mini_128k_instruct - #trt backend - - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-float16-maxbs:128-input_output_len:1000,1000-quant:fp8-tp:2] - - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-float16-maxbs:128-input_output_len:500,2000-quant:fp8-tp:2] + #pytorch backend + - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-pytorch-float16-maxbs:128-input_output_len:1000,1000-quant:fp8-tp:2] + - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-pytorch-float16-maxbs:128-input_output_len:500,2000-quant:fp8-tp:2] - condition: terms: @@ -314,15 +196,10 @@ trt_llm_release_perf_test: - '*h200*' - '*h20*' tests: - #mixtral_8x7b_v0.1 - #trt backend - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-quant:fp8-gpus:2] - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:512,32-quant:fp8-gpus:2] #pytorch backend - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-gpus:2] - #llama_v3.2_1b trt backend - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:500,2000-quant:fp8-con:250-gpus:2] + # 4 gpus test - condition: @@ -338,18 +215,12 @@ trt_llm_release_perf_test: - '*h20*' tests: - - perf/test_perf.py::test_perf[flan_t5_xxl-cppmanager-exe-plugin_ifb-float16-input_output_len:128,128-gpus:4] - - perf/test_perf.py::test_perf[flan_t5_xxl-cppmanager-exe-plugin_ifb-float16-input_output_len:512,32-gpus:4] - - perf/test_perf.py::test_perf[qwen_14b_chat-cppmanager-exe-plugin_ifb-float16-input_output_len:128,128-gpus:4] - - perf/test_perf.py::test_perf[qwen_14b_chat-cppmanager-ootb_except_mha-float16-input_output_len:128,128+512,32-gpus:4] - - perf/test_perf.py::test_perf[starcoder_15.5b-cppmanager-exe-plugin_ifb-float16-maxbs:1-input_output_len:512,200-reqs:10-gpus:4] - - perf/test_perf.py::test_perf[starcoder_15.5b-cppmanager-ootb_except_mha-float16-maxbs:1-input_output_len:512,200-reqs:10-gpus:4] #llama_v3.1_70b #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-streaming-bfloat16-input_output_len:128,128-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-streaming-bfloat16-input_output_len:512,32-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:128,128-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:512,32-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-streaming-bfloat16-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-streaming-bfloat16-input_output_len:512,32-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-input_output_len:512,32-gpus:4] # FP8 specific tests - condition: @@ -365,10 +236,6 @@ trt_llm_release_perf_test: - '*l40s*' - '*h20*' tests: - #llama_v3.1_70b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:512,200-quant:fp8-tp:4] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-float8-input_output_len:128,128-tp:4] #llama_v3.3_70b_instruct_fp8 #pytorch backend - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:500,2000-gpus:4] @@ -376,26 +243,6 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:2000,500-gpus:4] - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-gpus:4] - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-gpus:4] - # Llama-Nemotron-Super-49B-v3.3 - # cpp - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-reqs:4-con:1-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:4-con:1-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-input_output_len:500,2000-reqs:4-con:1-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:4-con:1-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-con:250-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-quant:fp8-con:250-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-input_output_len:500,2000-con:250-gpus:4] - # pyt - # bfloat16 - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:1-maxnt:5000-input_output_len:5000,500-reqs:8-con:1-tp:4-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:1-input_output_len:500,2000-reqs:8-con:1-tp:4-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:256-maxnt:5000-input_output_len:5000,500-reqs:250-con:250-tp:4-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:4-gpus:4] - # fp8 prequantized - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:1-maxnt:5000-input_output_len:5000,500-reqs:8-con:1-tp:4-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:1-input_output_len:500,2000-reqs:8-con:1-tp:4-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:256-maxnt:5000-input_output_len:5000,500-reqs:250-con:250-tp:4-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:4-gpus:4] - condition: @@ -411,15 +258,7 @@ trt_llm_release_perf_test: - '*a100*' - '*h20*' tests: - # E2E trtllm-bench #llama_v3.1_70b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-cppmanager-exe-plugin_ifb-float16-input_output_len:200,2000-reqs:64-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-bfloat16-input_output_len:200,2000-reqs:64-con:200-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-bfloat16-input_output_len:200,2000-reqs:8-con:1-gpus:8] # timeout for h20, move to l2 test - - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-bfloat16-input_output_len:2000,200-reqs:64-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-input_output_len:128,128-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-reqs:64-con:250-gpus:8] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-pytorch-bfloat16-input_output_len:2000,200-reqs:64-gpus:8] - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-pytorch-streaming-bfloat16-input_output_len:2000,200-reqs:64-gpus:8] @@ -427,8 +266,6 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.3_70b-bench-pytorch-bfloat16-input_output_len:2000,500-gpus:8] - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-pytorch-bfloat16-input_output_len:200,2000-reqs:64-con:200-gpus:8] - - perf/test_perf.py::test_perf[gpt_20b-bench-float16-maxbs:8-input_output_len:128,128-reqs:80-gpus:8] - - perf/test_perf.py::test_perf[gpt_20b-bench-float16-maxbs:8-input_output_len:512,32-reqs:80-gpus:8] - condition: ranges: @@ -444,25 +281,10 @@ trt_llm_release_perf_test: tests: # E2E trtllm-bench #mixtral_8x7b_v0.1_instruct - #trt backend - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-float16-input_output_len:128,128-reqs:64-gpus:8] # timeout for a100 - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-float16-input_output_len:128,128-reqs:10-con:50-gpus:8] # timeout for a100 - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-float16-input_output_len:128,128-reqs:10-con:1-gpus:8] # timeout for a100 #pytorch backend - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-pytorch-float16-input_output_len:128,128-reqs:64-gpus:8] # timeout for a100 - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-pytorch-float16-input_output_len:128,128-reqs:10-con:50-gpus:8] # timeout for a100 - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-pytorch-float16-input_output_len:128,128-reqs:10-con:1-gpus:8] # timeout for a100 - # Llama-3_1-Nemotron-Ultra-253B-v1 - # all cpp backend, bf16->fp8 post-quantized - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b-bench-bfloat16-maxbs:64-input_output_len:5000,500-quant:fp8-reqs:8-con:1-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:8-con:1-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b-bench-bfloat16-maxbs:64-input_output_len:5000,500-quant:fp8-reqs:250-con:250-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:250-con:250-tp:8-gpus:8] - # pyt backend, fp8 pre-quantized - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:1-maxnt:5000-input_output_len:5000,500-reqs:8-con:1-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:1-input_output_len:500,2000-reqs:8-con:1-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:256-maxnt:5000-input_output_len:5000,500-reqs:250-con:250-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:8-gpus:8] # llama_v3.1_405b_fp8 #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-maxbs:1-input_output_len:2000,500-reqs:8-con:1-tp:8-gpus:8] @@ -494,34 +316,25 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-kv_frac:0.85-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-kv_frac:0.85-input_output_len:512,32-ep:8-tp:8-gpus:8] - #deepseek_r1_fp8 + #llama_v4_scout_17b_16e_instruct #pytorch backend - - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-input_output_len:1000,1000-reqs:3000-ep:8-tp:8-gpus:8] - - -- condition: - ranges: - system_gpu_count: - gte: 8 - gpu_memory: - gt: 80000 - wildcards: - gpu: - - '*h200*' - - '*h20*' - tests: - - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:32-input_output_len:128,128-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-streaming-float8-maxbs:32-input_output_len:128,128-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:1-input_output_len:1000,2000-reqs:10-con:1-ep:4-tp:8-gpus:8] TIMEOUT(40)#min latency test - - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:128-maxnt:1127-input_output_len:1000,2000-reqs:5120-con:1024-ep:8-tp:8-gpus:8] TIMEOUT(80) #max throughput test - - perf/test_perf.py::test_perf[deepseek_r1_0528_fp8-bench-pytorch-float8-input_output_len:1000,1000-reqs:20000-ep:8-tp:8-gpus:8] TIMEOUT(120) - - perf/test_perf.py::test_perf[deepseek_r1_0528_fp8-bench-pytorch-float8-input_output_len:1000,2000-reqs:3000-ep:8-tp:8-gpus:8] TIMEOUT(100) - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-streaming-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:500,2000-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:2000,500-ep:8-tp:8-gpus:8] + #deepseek_r1_fp8 + #pytorch backend + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:32-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-streaming-float8-maxbs:32-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-input_output_len:1000,1000-reqs:3000-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:1-input_output_len:1000,2000-reqs:10-con:1-ep:4-tp:8-gpus:8] #min latency test + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:128-maxnt:1127-input_output_len:1000,2000-reqs:5120-con:1024-ep:8-tp:8-gpus:8] TIMEOUT(80) #max throughput test + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp8-bench-pytorch-float8-input_output_len:1000,1000-reqs:20000-ep:8-tp:8-gpus:8] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp8-bench-pytorch-float8-input_output_len:1000,2000-reqs:3000-ep:8-tp:8-gpus:8] TIMEOUT(100) + # qwen3_235b_a22b_fp8 - perf/test_perf.py::test_perf[qwen3_235b_a22b_fp8-bench-pytorch-float8-input_output_len:1000,2000-con:256-ep:8-gpus:8] TIMEOUT(45) + # FP8 specific tests - condition: terms: @@ -537,14 +350,6 @@ trt_llm_release_perf_test: - '*h20*' tests: #llama_v3.3_70b_instruct_fp8 - #trt backend - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-float8-input_output_len:128,128-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-float8-input_output_len:512,32-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-streaming-float8-maxbs:16-input_output_len:512,32-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:8-con:1-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:64-con:250-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-input_output_len:500,2000-quant:fp8-reqs:8-con:1-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-input_output_len:500,2000-quant:fp8-reqs:64-con:250-gpus:8] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-streaming-float8-input_output_len:512,32-gpus:8] - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-gpus:8] @@ -553,7 +358,6 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-gpus:8] - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct_fp8-bench-pytorch-float8-input_output_len:2000,200-gpus:8] - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct_fp8-bench-pytorch-float8-input_output_len:1000,1000-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-input_output_len:128,128-gpus:8] # GB chip specific tests diff --git a/tests/integration/test_lists/qa/llm_perf_nim.yml b/tests/integration/test_lists/qa/llm_perf_nim.yml index d89e854378..42d48f261c 100644 --- a/tests/integration/test_lists/qa/llm_perf_nim.yml +++ b/tests/integration/test_lists/qa/llm_perf_nim.yml @@ -1,5 +1,316 @@ version: 0.0.1 -trt_llm_release_perf_l2_test: +llm_perf_nim: +# one gpu test +- condition: + ranges: + system_gpu_count: + gte: 1 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*a100*' + - '*l40s*' + - '*l20*' + - '*h20*' + tests: + # E2E trtllm-bench + #llama_v3.1_8b_instruct + #trt backend + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-streaming-bfloat16-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-streaming-bfloat16-input_output_len:512,32] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:512,32] + # Mistral-7B + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:128,128] + # Phi-4-mini-instruct + # cpp + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-maxnt:5000-input_output_len:5000,500-con:250] + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-con:250] + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-quant:fp8-con:250] + # reduced 'reqs' to fit timeout limit + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-quant:fp8-reqs:8-con:1] + +- condition: + ranges: + system_gpu_count: + gte: 1 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*h20*' + tests: + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-maxnt:5000-input_output_len:5000,500-reqs:10-con:1] + # Llama-3.1-Nemotron-Nano-8B-v1 + # cpp backend + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:500,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:1000,1000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:20000-input_output_len:20000,2000-quant:fp8-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-con:250] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:500,2000-con:250] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:1000,1000-con:250] + # pyt backend + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:500,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:1000,1000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:500,2000-reqs:500-con:250] + # FP8 prequantized pyt backend + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:500,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:1000,1000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:500-con:250] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:500,2000-reqs:500-con:250] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:1000,1000-reqs:500-con:250] + #long time llama_nemotron cases + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-reqs:8-con:1] # timeout for l20, l40s, a100 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:8-con:1] #timeout for l20, l40s, failed for a100 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:1000,1000-quant:fp8-reqs:8-con:1] # timeout for l20, l40s, failed on a100 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-quant:fp8-con:250] # failed for a100 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-con:250] # failed on A100 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-input_output_len:1000,1000-quant:fp8-con:250] # failed on A100 15 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:20000-input_output_len:20000,2000-con:250] # timeout for l20, l40s, a100 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-bfloat16-maxbs:64-maxnt:20000-input_output_len:20000,2000-quant:fp8-con:250] # timeout for l20, l40s, failed on A100 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:500-con:250] # failed for l20, need to extend context token to 5000 for l40s and a100, timeout for h20 + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:1000,1000-reqs:500-con:250] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-maxnt:20000-input_output_len:20000,2000-reqs:500-con:250] #need to extend context token to 20000 for l40s, timeout for h20, a100 + +# FP8 specific tests +- condition: + terms: + supports_fp8: true + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*l40s*' + - '*l20*' + - '*h20*' + - '*b200*' + - '*gb200*' + tests: + #llama_v3.1_8b + #trt backend + - perf/test_perf.py::test_perf[llama_v3.1_8b-cppmanager-exe-plugin_ifb-float16-maxbs:256-input_output_len:128,128-beams:4-quant:fp8] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:128,128-quant:w4a16_awq] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:128,128-quant:w4a8_awq] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-maxbs:256-input_output_len:128,128-quant:fp8] + #mistral_7b_v0.1 + #trt backend + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-maxbs:256-input_output_len:1000,1000-quant:fp8] + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-maxbs:256-input_output_len:500,2000-quant:fp8] + #phi_3_mini_4k_instruct + #trt backend + - perf/test_perf.py::test_perf[phi_3_mini_4k_instruct-bench-float16-maxbs:128-input_output_len:1000,1000-quant:fp8] + - perf/test_perf.py::test_perf[phi_3_mini_4k_instruct-bench-float16-maxbs:64-input_output_len:500,2000-quant:fp8] + +- condition: + terms: + supports_fp8: true + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*h20*' + tests: + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:10-con:1] + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:10-con:250] + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-bfloat16-maxbs:32-input_output_len:500,2000-quant:fp8-reqs:10-con:250] + + +# 2 gpus test +- condition: + ranges: + system_gpu_count: + gte: 2 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*a100*' + - '*l40s*' + - '*l20*' + - '*h20*' + tests: + #llama_v3.1_8b + #trt backend + - perf/test_perf.py::test_perf[llama_v3.1_8b-cppmanager-exe-plugin_ifb-bfloat16-mp-maxbs:256-input_output_len:128,128-pp:2] + #mixtral_8x7b_v0.1 + #trt backend + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-loras:8-gpus:2] + #llama_v3.2_1b + #trt backend + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-maxnt:5000-input_output_len:5000,500-reqs:10-con:1-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-maxnt:5000-input_output_len:5000,500-reqs:10-con:250-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:128,128-gpus:2] + #t5 + - perf/test_perf.py::test_perf[t5-bench-float16-input_output_len:128,20-gpus:2] + - perf/test_perf.py::test_perf[flan_t5_large-bench-float16-input_output_len:128,20-gpus:2] + +- condition: + ranges: + system_gpu_count: + gte: 2 + gpu_memory: + gt: 80000 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*a100*' + - '*h20*' + tests: + #llama_v3.1_70b + #trt backend + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:1024,1024-tp:2-gpus:2] + - perf/test_perf.py::test_perf[llama_70b_sq_per_tensor-cppmanager-exe-plugin_ifb-float16-input_output_len:128,128+512,32-gpus:2] + #mixtral_8x7b_v0.1 + #trt backend + - perf/test_perf.py::test_perf[mixtral_8x7b-cppmanager-exe-plugin_ifb-float16-mp-input_output_len:128,128+512,32-gpus:2] + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-gpus:2] + +# FP8 specific tests +- condition: + terms: + supports_fp8: true + ranges: + system_gpu_count: + gte: 2 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*l40s*' + - '*l20*' + - '*h20*' + tests: + #llama_v3.2_1b + #trt backend + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:512,32-quant:fp8-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:512,200-quant:fp8-gpus:2] + #mistral_7b_v0.1 + #trt backend + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:1000,1000-quant:fp8-tp:2] + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:500,2000-quant:fp8-tp:2] + #phi_3_mini_128k_instruct + #trt backend + - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-float16-maxbs:128-input_output_len:1000,1000-quant:fp8-tp:2] + - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-float16-maxbs:128-input_output_len:500,2000-quant:fp8-tp:2] + +- condition: + terms: + supports_fp8: true + ranges: + system_gpu_count: + gte: 2 + gpu_memory: + gt: 80000 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*h20*' + tests: + #mixtral_8x7b_v0.1 + #trt backend + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-quant:fp8-gpus:2] + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:512,32-quant:fp8-gpus:2] + #llama_v3.2_1b trt backend + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:500,2000-quant:fp8-con:250-gpus:2] + +# 4 gpus test +- condition: + ranges: + system_gpu_count: + gte: 4 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*a100*' + - '*l40s*' + - '*h20*' + tests: + - perf/test_perf.py::test_perf[starcoder_15.5b-cppmanager-exe-plugin_ifb-float16-maxbs:1-input_output_len:512,200-reqs:10-gpus:4] + - perf/test_perf.py::test_perf[starcoder_15.5b-cppmanager-ootb_except_mha-float16-maxbs:1-input_output_len:512,200-reqs:10-gpus:4] + #llama_v3.1_70b + #trt backend + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-streaming-bfloat16-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-streaming-bfloat16-input_output_len:512,32-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:512,32-gpus:4] + +# FP8 specific tests +- condition: + terms: + supports_fp8: true + ranges: + system_gpu_count: + gte: 4 + wildcards: + gpu: + - '*b200*' + - '*gb200*' + - '*h100*' + - '*h200*' + - '*l40s*' + - '*h20*' + tests: + #llama_v3.1_70b + #trt backend + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:512,200-quant:fp8-tp:4] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-float8-input_output_len:128,128-tp:4] + # Llama-Nemotron-Super-49B-v3.3 + # cpp + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-reqs:4-con:1-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:4-con:1-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-input_output_len:500,2000-reqs:4-con:1-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:4-con:1-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-con:250-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-quant:fp8-con:250-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-input_output_len:500,2000-con:250-gpus:4] + # pyt + # bfloat16 + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:1-maxnt:5000-input_output_len:5000,500-reqs:8-con:1-tp:4-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:1-input_output_len:500,2000-reqs:8-con:1-tp:4-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:256-maxnt:5000-input_output_len:5000,500-reqs:250-con:250-tp:4-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:4-gpus:4] + # fp8 prequantized + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:1-maxnt:5000-input_output_len:5000,500-reqs:8-con:1-tp:4-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:1-input_output_len:500,2000-reqs:8-con:1-tp:4-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:256-maxnt:5000-input_output_len:5000,500-reqs:250-con:250-tp:4-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:4-gpus:4] + +- condition: + ranges: + system_gpu_count: + gte: 8 + gpu_memory: + gt: 80000 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*a100*' + - '*h20*' + tests: + # E2E trtllm-bench + #llama_v3.1_70b + #trt backend + - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-cppmanager-exe-plugin_ifb-float16-input_output_len:200,2000-reqs:64-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-bfloat16-input_output_len:200,2000-reqs:64-con:200-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-bfloat16-input_output_len:200,2000-reqs:8-con:1-gpus:8] # timeout for h20, move to l2 test + - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-bfloat16-input_output_len:2000,200-reqs:64-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-input_output_len:128,128-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-reqs:64-con:250-gpus:8] + - condition: ranges: system_gpu_count: @@ -8,9 +319,27 @@ trt_llm_release_perf_l2_test: gt: 100000 wildcards: gpu: + - '*h100*' - '*h200*' - '*h20*' tests: + #mixtral_8x7b_v0.1_instruct + #trt backend + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-float16-input_output_len:128,128-reqs:64-gpus:8] # timeout for a100 + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-float16-input_output_len:128,128-reqs:10-con:50-gpus:8] # timeout for a100 + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-float16-input_output_len:128,128-reqs:10-con:1-gpus:8] # timeout for a100 + # Llama-3_1-Nemotron-Ultra-253B-v1 + # all cpp backend, bf16->fp8 post-quantized + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b-bench-bfloat16-maxbs:64-input_output_len:5000,500-quant:fp8-reqs:8-con:1-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:8-con:1-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b-bench-bfloat16-maxbs:64-input_output_len:5000,500-quant:fp8-reqs:250-con:250-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:250-con:250-tp:8-gpus:8] + # pyt backend, fp8 pre-quantized + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:1-maxnt:5000-input_output_len:5000,500-reqs:8-con:1-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:1-input_output_len:500,2000-reqs:8-con:1-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:256-maxnt:5000-input_output_len:5000,500-reqs:250-con:250-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:8-gpus:8] + #deepseek_r1_fp8 - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:1-input_output_len:1000,2000-reqs:10-con:1-ep:4-tp:8-gpus:8] #min latency test - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:128-maxnt:1127-input_output_len:1000,2000-reqs:5120-con:1024-ep:8-tp:8-gpus:8] #max throughput test - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-maxnt:20000-input_output_len:20000,2000-reqs:500-con:250] @@ -30,14 +359,28 @@ trt_llm_release_perf_l2_test: - '*h20*' tests: - perf/test_perf.py::test_perf[mixtral_8x22b_v0.1-bench-float16-input_output_len:512,512-quant:fp8-tp:4] # timeout for h100 - # Llama-3.3-Nemotron-Super-49B-v1 - # trt backend - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-input_output_len:5000,500-reqs:4-con:1-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-input_output_len:5000,500-quant:fp8-reqs:4-con:1-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-input_output_len:500,2000-reqs:4-con:1-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:64-input_output_len:500,2000-quant:fp8-reqs:4-con:1-gpus:4] - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-input_output_len:5000,500-con:250-gpus:4] # timeout for h100 - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-input_output_len:5000,500-quant:fp8-con:250-gpus:4] # timeout for h100 - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-input_output_len:500,2000-con:250-gpus:4] # timeout for h100 - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-maxbs:16-input_output_len:500,2000-quant:fp8-con:250-gpus:4] # timeout for h100 - - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-streaming-bfloat16-maxbs:16-input_output_len:500,2000-quant:fp8-con:250-gpus:4] # timeout for h100 + #llama_v3.3_70b_instruct_fp8 + # FP8 specific tests +- condition: + terms: + supports_fp8: true + ranges: + system_gpu_count: + gte: 8 + wildcards: + gpu: + - '*b200*' + - '*h100*' + - '*h200*' + - '*l40s*' + - '*h20*' + tests: + #llama_v3.3_70b_instruct_fp8 + #trt backend + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-float8-input_output_len:128,128-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-float8-input_output_len:512,32-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-streaming-float8-maxbs:16-input_output_len:512,32-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:8-con:1-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-maxnt:5000-input_output_len:5000,500-quant:fp8-reqs:64-con:250-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-input_output_len:500,2000-quant:fp8-reqs:8-con:1-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-input_output_len:500,2000-quant:fp8-reqs:64-con:250-gpus:8] diff --git a/tests/integration/test_lists/qa/llm_perf_sanity.yml b/tests/integration/test_lists/qa/llm_perf_sanity.yml index 2853c656a8..b7293e74b2 100644 --- a/tests/integration/test_lists/qa/llm_perf_sanity.yml +++ b/tests/integration/test_lists/qa/llm_perf_sanity.yml @@ -1,5 +1,5 @@ version: 0.0.1 -trt_llm_release_perf_sanity_test: +llm_perf_sanity: - condition: ranges: system_gpu_count: @@ -14,28 +14,15 @@ trt_llm_release_perf_sanity_test: - '*h20*' tests: # E2E trtllm-bench - - perf/test_perf.py::test_perf[gpt_350m_moe-bench-float16-maxbs:64-input_output_len:128,128] - # E2E BERT - - perf/test_perf.py::test_perf[bert_large-bench-float16-maxbs:32-input_len:128+512] - - perf/test_perf.py::test_perf[roberta_base-bench-float16-maxbs:32-input_len:128+512] - - # Common models for all GPUs - - perf/test_perf.py::test_perf[starcoder2_3b-bench-float16-maxbs:1-input_output_len:512,200-reqs:10] - - perf/test_perf.py::test_perf[mamba_130m-bench-float16-input_output_len:128,128] - - perf/test_perf.py::test_perf[mamba_2.8b-bench-float16-input_output_len:128,128] - - # E2E ENC-DEC - - perf/test_perf.py::test_perf[mbart_large_50_many_to_one_mmt-cppmanager-exe-plugin_ifb-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[bart_large_cnn-bench-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[t5-bench-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[flan_t5_base-bench-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[flan_t5_large-bench-float16-input_output_len:128,20] - - perf/test_perf.py::test_perf[whisper_large_v3-bench-float16-input_output_len:128,20] #llama_v3.1_8b_instruct #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:128,128] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:512,32] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-pytorch-bfloat16-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-pytorch-bfloat16-input_output_len:512,32] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:500,2000] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:512,32] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:500,2000] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:512,32] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:128,128] # Phi-4-multimodal-instruct @@ -44,38 +31,11 @@ trt_llm_release_perf_sanity_test: - perf/test_perf.py::test_perf[bielik_11b_v2.2_instruct-bench-pytorch-bfloat16-input_output_len:128,128] # Ministral-8B - perf/test_perf.py::test_perf[ministral_8b-bench-pytorch-bfloat16-input_output_len:500,2000-reqs:500-con:250] + - perf/test_perf.py::test_perf[phi_4_mini_instruct-bench-pytorch-bfloat16-input_output_len:500,2000] # Test list validation - test_list_validation.py::test_list_validation -# Tests for GPUs with memory > 25000MB -- condition: - ranges: - system_gpu_count: - gte: 1 - gpu_memory: - gt: 25000 - wildcards: - gpu: - - '*h100*' - - '*h200*' - - '*a100*' - - '*l40s*' - - '*l20*' - - '*h20*' - tests: - # E2E gptManagerBenchmark IFB - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-cppmanager-exe-static_batching-plugin_ifb-float16-bs:8+64-input_output_len:128,128+512,32] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-cppmanager-exe-plugin_ifb-bfloat16-gwp:0.0-input_output_len:128,128+512,32] - #llama_v3.1_8b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:512,32] - #pytorch backend - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:128,128] - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:512,32] - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:128,128] - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:512,32] - - perf/test_perf.py::test_perf[qwen2_7b_instruct-bench-float16-input_output_len:128,128] # FP8 specific tests - condition: @@ -90,15 +50,14 @@ trt_llm_release_perf_sanity_test: - '*h20*' tests: #llama_v3.1_8b_instruct_fp8 - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:128,128-quant:fp8] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:512,32-quant:fp8] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:512,32] - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-maxnt:5000-input_output_len:5000,500-reqs:8-con:1] - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:500,2000-reqs:8-con:1] - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b_fp8-bench-pytorch-float8-maxbs:512-input_output_len:1000,1000-reqs:8-con:1] + - perf/test_perf.py::test_perf[bielik_11b_v2.2_instruct_fp8-bench-pytorch-float8-input_output_len:1000,1000-con:250] + - perf/test_perf.py::test_perf[ministral_8b_fp8-bench-pytorch-float8-input_output_len:500,2000-reqs:500-con:250] # Tests for systems with 2+ GPUs - condition: @@ -114,13 +73,7 @@ trt_llm_release_perf_sanity_test: - '*l20*' - '*h20*' tests: - - perf/test_perf.py::test_perf[t5-bench-float16-maxbs:1-input_output_len:128,20-gpus:2] - - perf/test_perf.py::test_perf[flan_t5_large-bench-float16-maxbs:1-input_output_len:128,20-gpus:2] #llama_v3.1_8b_instruct - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:128,128-quant:int8-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-maxbs:256-input_output_len:128,128-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-streaming-bfloat16-input_output_len:128,128-gpus:2] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-maxbs:256-input_output_len:128,128-gpus:2] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-pytorch-streaming-bfloat16-input_output_len:128,128-gpus:2] @@ -142,10 +95,7 @@ trt_llm_release_perf_sanity_test: - '*l20*' - '*h20*' tests: - - perf/test_perf.py::test_perf[llama_v3.1_8b-cppmanager-exe-plugin_ifb-float16-mp-input_output_len:128,128-quant:fp8-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] - - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-quant:fp8-gpus:2] + #mixtral_8x7b_v0.1_fp8 pytorch backend - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-gpus:2] # Tests for systems with 2+ GPUs and high memory @@ -181,11 +131,9 @@ trt_llm_release_perf_sanity_test: tests: #llama_v3.1_70b #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-input_output_len:128,128-reqs:10-gpus:4] #pytorch backend + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-input_output_len:128,128-gpus:4] - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-maxbs:1-input_output_len:128,128-reqs:10-gpus:4] - - perf/test_perf.py::test_perf[qwen_14b_chat-cppmanager-ootb_except_mha-float16-input_output_len:128,128-gpus:4] - - perf/test_perf.py::test_perf[starcoder_15.5b-cppmanager-exe-plugin_ifb-float16-maxbs:1-input_output_len:512,200-reqs:10-gpus:4] # FP8 specific tests - condition: @@ -201,6 +149,7 @@ trt_llm_release_perf_sanity_test: - '*l40s*' - '*h20*' tests: + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:2000,500-gpus:4] - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-pytorch-bfloat16-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:4-gpus:4] - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-float8-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:4-gpus:4] @@ -220,16 +169,11 @@ trt_llm_release_perf_sanity_test: - '*h20*' tests: #llama_v3.1_70b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-input_output_len:2000,200-reqs:10-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-input_output_len:200,2000-reqs:10-gpus:8] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-maxbs:1-input_output_len:2000,200-reqs:10-gpus:8] - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-maxbs:1-input_output_len:200,2000-reqs:10-gpus:8] - perf/test_perf.py::test_perf[llama_v3.3_70b-bench-pytorch-bfloat16-input_output_len:500,2000-gpus:8] - perf/test_perf.py::test_perf[llama_v3.3_70b-bench-pytorch-bfloat16-input_output_len:2000,500-gpus:8] - - perf/test_perf.py::test_perf[gpt_20b-bench-float16-maxbs:1-input_output_len:128,128-reqs:10-gpus:8] - # FP8 tests for systems with 8+ GPUs - condition: @@ -247,13 +191,13 @@ trt_llm_release_perf_sanity_test: tests: #llama_v3.1_70b - #trt backend - - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-input_output_len:128,128-quant:fp8-gpus:8] #pytorch backend - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-maxbs:1-input_output_len:512,32-quant:fp8-gpus:8] #llama_v3.3_70b_instruct_fp8 #pytorch backend - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-gpus:8] + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1_instruct-bench-pytorch-float16-input_output_len:128,128-reqs:64-gpus:8] + - condition: terms: @@ -270,4 +214,9 @@ trt_llm_release_perf_sanity_test: tests: - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-input_output_len:128,128] - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-streaming-pytorch-float8-input_output_len:128,128] - - perf/test_perf.py::test_perf[qwen3_235b_a22b_fp8-bench-pytorch-float8-input_output_len:128,128-con:256-ep:8-gpus:8] + - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-streaming-float8-input_output_len:2000,500] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:20000-kv_frac:0.6-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8] TIMEOUT(100) + - perf/test_perf.py::test_perf[qwen3_235b_a22b_fp8-bench-pytorch-float8-input_output_len:1000,2000-con:256-ep:8-gpus:8] TIMEOUT(45) diff --git a/tests/integration/test_lists/qa/llm_trt_integration_perf.yml b/tests/integration/test_lists/qa/llm_trt_integration_perf.yml index 1d2e3e0150..4841feacc5 100644 --- a/tests/integration/test_lists/qa/llm_trt_integration_perf.yml +++ b/tests/integration/test_lists/qa/llm_trt_integration_perf.yml @@ -1,5 +1,5 @@ version: 0.0.1 -trt_llm_integration_perf_test: +llm_trt_integration_perf: - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/qa/llm_trt_integration_perf_sanity.yml b/tests/integration/test_lists/qa/llm_trt_integration_perf_sanity.yml index 59cf7474a0..96152af29a 100644 --- a/tests/integration/test_lists/qa/llm_trt_integration_perf_sanity.yml +++ b/tests/integration/test_lists/qa/llm_trt_integration_perf_sanity.yml @@ -1,5 +1,5 @@ version: 0.0.1 -trt_llm_integration_perf_sanity_test: +llm_trt_integration_perf_sanity: - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 6caee5b69f..ce285faa79 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -25,6 +25,7 @@ l0_a10: - test_e2e.py::test_openai_chat_structural_tag_example - test_e2e.py::test_openai_chat_json_example - test_e2e.py::test_openai_chat_multimodal_example + - test_e2e.py::test_openai_prometheus - test_e2e.py::test_openai_lora - test_e2e.py::test_trtllm_serve_multimodal_example - test_e2e.py::test_trtllm_serve_lora_example @@ -99,6 +100,11 @@ l0_a10: - unittest/test_model_runner_cpp.py - unittest/llmapi/test_build_cache.py - unittest/llmapi/test_llm_utils.py + - unittest/llmapi/test_gc_utils.py + - unittest/llmapi/test_reasoning_parser.py + - unittest/llmapi/test_serialization.py + - unittest/llmapi/test_utils.py + - unittest/llmapi/test_llm_args.py - accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype # 1 min - accuracy/test_cli_flow.py::TestGpt2::test_beam_search # 1 min - accuracy/test_cli_flow.py::TestGpt2::test_beam_search_large # 6 mins diff --git a/tests/integration/test_lists/test-db/l0_a30.yml b/tests/integration/test_lists/test-db/l0_a30.yml index ce8058136f..5ec16996e7 100644 --- a/tests/integration/test_lists/test-db/l0_a30.yml +++ b/tests/integration/test_lists/test-db/l0_a30.yml @@ -18,8 +18,7 @@ l0_a30: - unittest/_torch/modeling -k "modeling_phi3" - unittest/_torch/modeling -k "modeling_qwen" - unittest/_torch/modeling -k "modeling_qwen_moe" - - unittest/_torch/modeling -k "modeling_exaone4" - - unittest/_torch/auto_deploy/unit/singlegpu + - unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison" - unittest/_torch/test_beam_search.py - condition: ranges: diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 7985cfa758..cb36129a14 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -17,8 +17,9 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_4] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_64] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] @@ -42,10 +43,21 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=False-fp8kv=False-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True] + - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency] + - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRITON] + - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton] - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551 - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B] - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8] @@ -54,11 +66,12 @@ l0_b200: - test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B] - test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct] - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False] - - unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)" TIMEOUT (90) + - unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)" TIMEOUT (120) - unittest/_torch -k "modeling_llama" - unittest/_torch/modeling -k "modeling_mixtral" - unittest/_torch/modeling -k "modeling_deepseek" - - unittest/_torch/auto_deploy/unit/singlegpu + - unittest/_torch/modeling -k "modeling_gpt_oss" + - unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison" - unittest/_torch/speculative/test_eagle3.py - unittest/_torch/speculative/test_kv_cache_reuse.py - unittest/_torch/speculative/test_dynamic_spec_decode.py diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 61b8071726..fb3f518a68 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -15,6 +15,8 @@ l0_dgx_b200: backend: pytorch tests: - unittest/_torch/multi_gpu_modeling -k "deepseek" + - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP] + - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False] @@ -68,5 +70,22 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp4-cuda_graph=True] - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp4-cuda_graph=True] + - accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRITON] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-TRITON] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-TRITON] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8] + - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] + - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] + - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] + - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8] + - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index 3a8e6aa9c9..9b6d5b6f1f 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -33,6 +33,7 @@ l0_dgx_h100: - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2_genpp2[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0] + - disaggregated/test_disaggregated.py::test_disaggregated_genbs1[TinyLlama-1.1B-Chat-v1.0] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] @@ -42,6 +43,8 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1] @@ -50,8 +53,11 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2] + - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend - test_e2e.py::test_ptp_quickstart_advanced_bs1 - test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8] + - unittest/_torch/modeling/test_modeling_pixtral.py::test_tensor_parallelism - condition: ranges: system_gpu_count: @@ -69,6 +75,10 @@ l0_dgx_h100: - unittest/_torch/multi_gpu_modeling -k "deepseek" - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency] + - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0] + - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1] + - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0] + - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype1] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] @@ -107,6 +117,10 @@ l0_dgx_h100: - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_mpi[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8] + - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] + - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] + - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] + - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_overlap[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8] @@ -118,6 +132,27 @@ l0_dgx_h100: - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_conditional[DeepSeek-V3-Lite-bf16] - disaggregated/test_workers.py::test_workers_conditional_disaggregation_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16] - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16] +- condition: + ranges: + system_gpu_count: + gte: 4 + lte: 4 + wildcards: + gpu: + - '*h100*' + linux_distribution_name: ubuntu* + terms: + stage: pre_merge + backend: pytorch + auto_trigger: gpt_oss + tests: + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRITON] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-TRITON] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-TRITON] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4] - condition: ranges: system_gpu_count: @@ -183,5 +218,6 @@ l0_dgx_h100: terms: stage: post_merge backend: triton + auto_trigger: others tests: - triton_server/test_triton_llm.py::test_llmapi_backend[4-0-disableDecoupleMode-tensorrt_llm] diff --git a/tests/integration/test_lists/test-db/l0_dgx_h200.yml b/tests/integration/test_lists/test-db/l0_dgx_h200.yml index 33542dd8d7..4266722545 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h200.yml @@ -136,6 +136,7 @@ l0_dgx_h200: - unittest/llmapi/test_llm_multi_gpu.py -m "gpu2 and part3" - unittest/llmapi/test_llm_multi_gpu.py -m "gpu4 and part0" - unittest/llmapi/test_llm_multi_gpu.py -m "not (gpu2 or gpu4)" + - unittest/llmapi/test_llm_kv_cache_events.py::test_llm_api_attention_dp_kv_events - examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-enable_fp8] - llmapi/test_llm_e2e.py::test_llmapi_exit_multi_gpu - test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] diff --git a/tests/integration/test_lists/test-db/l0_gb200.yml b/tests/integration/test_lists/test-db/l0_gb200.yml index 6e1f621947..ac39fbdc88 100644 --- a/tests/integration/test_lists/test-db/l0_gb200.yml +++ b/tests/integration/test_lists/test-db/l0_gb200.yml @@ -20,6 +20,9 @@ l0_gb200: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 0289c317a1..64f6498d09 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -15,11 +15,13 @@ l0_h100: tests: # ------------- PyTorch tests --------------- # Only key models in H100: llama/mixtral/nemotron/deepseek + - unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py::test_trtllm_bench_backend_comparison - unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)" TIMEOUT (90) - unittest/_torch -k "modeling_llama" - unittest/_torch/modeling -k "modeling_mixtral" - unittest/_torch/modeling -k "modeling_nemotron" - unittest/_torch/modeling -k "modeling_gemma3" + - unittest/_torch/modeling -k "modeling_gpt_oss" - unittest/disaggregated/test_disagg_utils.py - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py @@ -30,6 +32,7 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90) - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=TRTLLM-torch_compile=False] @@ -40,6 +43,8 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=False-fp8kv=False-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True] @@ -80,7 +85,8 @@ l0_h100: - disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_llama_context_capacity[False-False-DeepSeek-V3-Lite-fp8/fp8] - - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_spec_dec_batch_slot_limit[False-EAGLE3-LLaMA3.1-Instruct-8B-Llama-3.1-8B-Instruct] + - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_spec_dec_batch_slot_limit[True-False-EAGLE3-LLaMA3.1-Instruct-8B-Llama-3.1-8B-Instruct] + - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_spec_dec_batch_slot_limit[False-False-EAGLE3-LLaMA3.1-Instruct-8B-Llama-3.1-8B-Instruct] - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-non-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-] @@ -172,20 +178,20 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=FLASHINFER-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=FLASHINFER-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=FLASHINFER-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] @@ -208,6 +214,9 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance] - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] - condition: diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index ad2c7efd8f..49c9a6d010 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -169,7 +169,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-disable_attention_plugin-disable_context_fmha-tp:2-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5234058) examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:2-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5234058) examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:2-pp:1-float16-RobertaForQuestionAnswering-bert/roberta-base-squad2] SKIP (https://nvbugs/5234058) -disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5247271) full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) @@ -210,10 +209,8 @@ perf/test_perf.py::test_perf[bart_large_cnn-bench-float16-input_output_len:128,2 perf/test_perf.py::test_perf[mamba_130m-bench-float16-input_output_len:128,128] SKIP (https://nvbugspro.nvidia.com/bug/5295411) perf/test_perf.py::test_perf[bert_large-bench-float16-maxbs:32-input_len:128+512] SKIP (https://nvbugspro.nvidia.com/bug/5295411) perf/test_perf.py::test_perf[roberta_base-bench-float16-maxbs:32-input_len:128+512] SKIP (https://nvbugspro.nvidia.com/bug/5295411) -disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5328160) stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5328495) full:B200/examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5292737) -full:B200/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5295470) examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long] SKIP (https://nvbugs/5324976) examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_py_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5333849) examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5333818) @@ -234,33 +231,23 @@ accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/ examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5355128) examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128) stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5375646) -full:GH200/disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5375966) -accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620) +full:L40S/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620) test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-405B-FP8-llama-3.1-model/Llama-3.1-405B-Instruct-FP8] SKIP (https://nvbugs/5380570) test_e2e.py::test_ptp_quickstart_advanced_8gpus[Nemotron-Ultra-253B-nemotron-nas/Llama-3_1-Nemotron-Ultra-253B-v1] SKIP (https://nvbugs/5380570) -examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (https://nvbugs/5385987) examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5385992) examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5387422) examples/test_multimodal.py::test_llm_multimodal_general[fuyu-8b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5387424) triton_server/test_triton_llm.py::test_llava_onevision[test_basic-False-1---False-True-False-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-max_utilization---1-1-1-False-tensorrt_llm_bls] SKIP (https://nvbugs/5396437) triton_server/test_triton_llm.py::test_llava_onevision[test_video-False-1---False-True-False-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-tensorrt_llm_bls] SKIP (https://nvbugs/5396437) -accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype SKIP (https://nvbugs/5401114) -test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] SKIP (https://nvbugs/5401114) -test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-False] SKIP (https://nvbugs/5401114) examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5401233) examples/test_recurrentgemma.py::test_llm_recurrentgemma_2gpu[recurrentgemma-2b] SKIP (https://nvbugs/5401233) -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5409414) test_e2e.py::test_openai_multi_chat_example SKIP (https://nvbugs/5409416) test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5409420) -llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5410399) unittest/trt/attention/test_gpt_attention.py -k "partition0" SKIP (https://nvbugs/5412456) unittest/trt/attention/test_gpt_attention.py -k "partition1" SKIP (https://nvbugs/5412456) unittest/trt/attention/test_gpt_attention.py -k "partition2" SKIP (https://nvbugs/5412456) unittest/trt/attention/test_gpt_attention.py -k "partition3" SKIP (https://nvbugs/5412456) test_e2e.py::test_ptp_quickstart_multimodal[qwen2-vl-7b-instruct-Qwen2-VL-7B-Instruct-image-False] SKIP (https://nvbugs/5414909) -unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5418673) -unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5418673) -examples/test_llama.py::test_llm_api_lookahead_decoding_1gpu[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct] SKIP (https://nvbugs/5419066) examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5141288) examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_vl_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (https://nvbugs/5419067) examples/test_qwen.py::test_llm_qwen_awq_single_gpu_summary[qwen2_vl_7b_instruct-nb:4] SKIP (https://nvbugs/5419068) @@ -270,20 +257,61 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5430124) examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5431139) -disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5431127) -accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5427801) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320) accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541) accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541) -accuracy/test_cli_flow.py::TestLlama3_3_70BInstruct::test_fp8_prequantized_tp4 SKIP (https://nvbugs/5409414) -accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 SKIP (https://nvbugs/5409414) accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype SKIP (https://nvbugs/5433543) accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype_long_rope SKIP (https://nvbugs/5433543) accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5431139) -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5431139) -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5431139) examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451) examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451) examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451) -accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384) +accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437405) +accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 SKIP (https://nvbugs/5440241) +test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False] SKIP (https://nvbugs/5444060,https://nvbugs/5444095) +test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False] SKIP (https://nvbugs/5444060,https://nvbugs/5444095) +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5433545) +examples/test_nemotron_nas.py::test_nemotron_nas_summary_1gpu[DeciLM-7B] SKIP (https://nvbugs/5444636) +accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive SKIP (https://nvbugs/5444627) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] SKIP (https://nvbugs/5444687) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] SKIP (https://nvbugs/5444687) +examples/test_qwen2audio.py::test_llm_qwen2audio_single_gpu[qwen2_audio_7b_instruct] SKIP (https://nvbugs/5447530) +examples/test_nemotron_nas.py::test_nemotron_nas_summary_2gpu[DeciLM-7B] SKIP (https://nvbugs/5444636) +examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:4] SKIP (https://nvbugs/5453709) +examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (https://nvbugs/5453709) +examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5453709) +examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf] SKIP (https://nvbugs/5453742) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] SKIP (https://nvbugs/5403818) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] SKIP (https://nvbugs/5403818) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False] SKIP (https://nvbugs/5403818) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=True] SKIP (https://nvbugs/5403818) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False] SKIP (https://nvbugs/5403818) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True] SKIP (https://nvbugs/5403818) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False] SKIP (https://nvbugs/5403818) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True] SKIP (https://nvbugs/5403818) +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5442827,https://nvbugs/5445466) +test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-70B-FP8-llama-3.1-model/Llama-3.1-70B-Instruct-FP8] SKIP (https://nvbugs/5453992) +accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype SKIP (https://nvbugs/5454875) +accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] SKIP (https://nvbugs/5454898) +accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] SKIP (https://nvbugs/5454898) +examples/test_llm_api_with_mpi.py::test_llm_api_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5434372) +triton_server/test_triton.py::test_gpt_ib[gpt-ib] SKIP (https://nvbugs/5431116) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=True] SKIP (https://nvbugs/5457489) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=False] SKIP (https://nvbugs/5457489) +accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=True] SKIP (https://nvbugs/5457489) +disaggregated/test_workers.py::test_workers_kv_cache_events[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5457504) +accuracy/test_llm_api.py::TestMistral_Nemo_12B_Base::test_fp8 SKIP (https://nvbugs/5413197) +triton_server/test_triton.py::test_gpt_ib_streaming[gpt-ib-streaming] SKIP (https://nvbugs/5371349) +triton_server/test_triton.py::test_gpt_ib_ptuning[gpt-ib-ptuning] SKIP (https://nvbugs/5445624) +triton_server/test_triton.py::test_mistral_ib_mm[mistral-ib-mm] SKIP (https://nvbugs/5371343) +triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482) +triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485) +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320) +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend SKIP (https://nvbugs/5448437) +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend SKIP (https://nvbugs/5448437) +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] SKIP (https://nvbugs/5445466) +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] SKIP (https://nvbugs/5445466) +test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] SKIP (https://nvbugs/5459817) +llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5461796) +disaggregated/test_disaggregated.py::test_disaggregated_genbs1[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5459811) diff --git a/tests/scripts/perf-sanity/README.md b/tests/scripts/perf-sanity/README.md new file mode 100644 index 0000000000..cd8f5639e8 --- /dev/null +++ b/tests/scripts/perf-sanity/README.md @@ -0,0 +1,138 @@ +# TensorRT-LLM Benchmark Test System + +Benchmarking scripts for TensorRT-LLM serving performance tests with configuration-driven test cases and CSV report generation. + +## Overview + +- Run performance benchmarks across multiple model configurations +- Manage test cases through YAML configuration files +- Generate comprehensive CSV reports with complete test case coverage +- Support selective execution of specific test cases + +## Scripts Overview + +### 1. `benchmark_config.yaml` - Test Case Configuration +**Purpose**: Defines all benchmark test cases in a structured YAML format. + +**Structure**: +```yaml +test_cases: + - id: 1 + model: "70B-FP8" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 512 + isl: 1024 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] +``` + +**Configuration Fields**: +- `id`: Unique identifier for the test case +- `model`: Model name (e.g., "70B-FP8", "Scout-FP4") +- `gpus`: Number of GPUs to use +- `tp`: Tensor parallelism size +- `ep`: Expert parallelism size +- `attn_backend`: Attention backend ("TRTLLM", "FLASHINFER") +- `moe_backend`: MoE backend ("DEEPGEMM", "TRTLLM", "CUTLASS", "") +- `enable_attention_dp`: Enable attention data parallelism +- `free_gpu_mem_fraction`: GPU memory fraction to reserve +- `max_batch_size`: Maximum batch size +- `isl`: Input sequence length +- `osl`: Output sequence length +- `max_num_tokens`: Maximum number of tokens +- `moe_max_num_tokens`: Maximum number of tokens for MoE +- `concurrency_iterations`: List of [concurrency, iteration] pairs + + +### 2. `run_benchmark_serve.py` - Main Benchmark Runner +**Purpose**: Executes performance benchmarks based on YAML configuration files. + +**Usage**: +```bash +python run_benchmark_serve.py --output_folder <output_folder> --config_file <config_file> [--skip <skip_pattern>] [--select <select_pattern>] +``` + +**Arguments**: +- `--output_folder`: Directory to store benchmark results (required) +- `--config_file`: Path to YAML configuration file (required) +- `--skip`: Skip pattern for specific test cases/concurrencies (optional, default: no skipping) +- `--select`: Select pattern for specific test cases/concurrencies (optional, default: all test cases) + +**Examples**: +```bash +# Run all test cases +python run_benchmark_serve.py --output_folder results --config_file benchmark_config.yaml --skip default --select default + +# Skip specific test cases +python run_benchmark_serve.py --output_folder results --config_file benchmark_config.yaml --skip "2-1,4" + +# Run specific concurrencies from specific test cases +python run_benchmark_serve.py --output_folder results --config_file benchmark_config.yaml --select "1,2-3" + +``` + +**Skip Pattern**: +Format: `"test_case1,test_case2,test_case3"` or `"test_case1-concurrency1,test_case2-concurrency3"` +- `"2,4"`: Skip test cases 2 and 4 entirely +- `"2-1,4-2"`: Skip test case 2's 1st concurrency and test case 4's 2nd concurrency +- `"default"` or empty: No skipping (default) + +**Select Pattern**: +Format: `"test_case1,test_case2,test_case3"` or `"test_case1-concurrency1,test_case2-concurrency3"` +- `"1,3,5"`: Run only test cases 1, 3, and 5 (all concurrencies) +- `"1-1,2-3"`: Run test case 1's 1st concurrency and test case 2's 3rd concurrency +- `"default"` or empty: Run all test cases (default) + + +### 3. `parse_benchmark_results.py` - Results Parser +**Purpose**: Parses benchmark log files and generates comprehensive CSV reports with all test cases from the configuration file. + +**Usage**: +```bash +python parse_benchmark_results.py --input_folder <input_folder> --output_csv <output_csv> --config_file <config_file> +``` + +**Arguments**: +- `input_folder`: Folder containing benchmark log files (serve.*.log) (required) +- `output_csv`: Output CSV filename for the results table (required) +- `config_file`: Path to benchmark_config.yaml file (required) + +**Examples**: +```bash +python parse_benchmark_results.py --config_file ./benchmark_logs --output_csv results.csv --input_folder ./benchmark_config.yaml + +``` + +### 4. `benchmark-serve.sh` - SLURM Job Script +**Usage**: +```bash +sbatch benchmark-serve.sh [IMAGE] [bench_dir] [output_dir] [select_pattern] [skip_pattern] +``` + +**Parameters**: +- `IMAGE`: Docker image (default: tensorrt-llm-staging/release:main-x86_64) +- `bench_dir`: Directory containing config file and benchmark scripts (default: current directory) +- `output_dir`: Directory containing output logs and csv. (default: current directory) +- `select_pattern`: Select pattern (default: default - all test cases) +- `skip_pattern`: Skip pattern (default: default - no skipping) + +**Examples**: +```bash + +bench_dir="/path/to/benchmark/scripts" +output_dir="/path/to/store/output/files" +sbatch --reservation=RES--COM-3970 --qos=reservation -D ${output_dir} ${bench_dir}/benchmark-serve.sh urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm-staging/release:main-x86_64 ${bench_dir} ${output_dir} "1-1" "" + +``` diff --git a/tests/scripts/perf-sanity/benchmark-serve.sh b/tests/scripts/perf-sanity/benchmark-serve.sh new file mode 100755 index 0000000000..a3dd58cf72 --- /dev/null +++ b/tests/scripts/perf-sanity/benchmark-serve.sh @@ -0,0 +1,68 @@ +#! /usr/bin/bash +#SBATCH -N1 +#SBATCH -n1 +#SBATCH --time=08:00:00 +#SBATCH --gres=gpu:8 + +set -ex + +env && hostname && nvidia-smi + +DEFAULT_IMAGE="urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm-staging/release:main-x86_64" +IMAGE=${1:-$DEFAULT_IMAGE} +bench_dir=${2:-$(pwd)} +output_dir=${3:-$(pwd)} +select_pattern=${4:-default} +skip_pattern=${5:-default} + +start_time=$(date '+%Y-%m-%d-%H:%M:%S') +output_folder=${output_dir}/benchmark.run.${SLURM_JOB_ID}.${start_time}.${select_pattern}.${skip_pattern} + +# Validate bench_dir exists +if [[ ! -d "$bench_dir" ]]; then + echo "Error: bench_dir '$bench_dir' does not exist" + exit 1 +fi + +if [[ ! -d "$output_dir" ]]; then + echo "Error: output_dir '$output_dir' does not exist" + exit 1 +fi + +# the docker user is root, otherwise it can not write logs to this folder when running in scratch +chmod 777 ${output_dir} +mkdir -p ${output_folder} && chmod 777 ${output_folder} + +cd ${output_dir} + +report_head() { + echo "trtllm-serve Job ${SLURM_JOB_ID} started at:${start_time} on:$(hostname) under:$(pwd) + output: ${output_folder} " +} + +run_benchmark_and_parse() { + # Run benchmark and parse results in a single Docker container + docker run --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ + --gpus all \ + -v /home/scratch.trt_llm_data:/home/scratch.trt_llm_data:ro \ + -v $output_dir:$output_dir:rw \ + -v $bench_dir:$bench_dir:ro \ + -w `pwd` \ + --pull always \ + ${IMAGE} \ + bash -c " + echo 'Running benchmarks...' + python3 ${bench_dir}/run_benchmark_serve.py --output_folder ${output_folder} --config_file ${bench_dir}/benchmark_config.yaml --select ${select_pattern} --skip ${skip_pattern} + + echo 'Benchmarks completed. Generating CSV report...' + if [[ -f '${bench_dir}/parse_benchmark_results.py' ]]; then + python3 ${bench_dir}/parse_benchmark_results.py --config_file ${bench_dir}/benchmark_config.yaml --input_folder ${output_folder} --output_csv ${output_folder}.csv + echo 'CSV report generated successfully' + else + echo 'Warning: parse_benchmark_results.py not found' + fi + " +} + +report_head +run_benchmark_and_parse diff --git a/tests/scripts/perf-sanity/benchmark_config.yaml b/tests/scripts/perf-sanity/benchmark_config.yaml new file mode 100644 index 0000000000..6b5d25e698 --- /dev/null +++ b/tests/scripts/perf-sanity/benchmark_config.yaml @@ -0,0 +1,468 @@ +test_cases: + - id: 1 + model: "70B-FP8" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 1024 + isl: 1024 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 2 + model: "70B-FP8" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 1024 + isl: 8192 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 3 + model: "70B-FP8" + gpus: 4 + tp: 4 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 1024 + isl: 1024 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 4 + model: "70B-FP8" + gpus: 4 + tp: 4 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 1024 + isl: 8192 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 5 + model: "70B-FP4" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 1024 + isl: 1024 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 6 + model: "70B-FP4" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 1024 + isl: 8192 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 7 + model: "70B-FP4" + gpus: 4 + tp: 4 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 1024 + isl: 1024 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 8 + model: "70B-FP4" + gpus: 4 + tp: 4 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.9 + max_batch_size: 1024 + isl: 8192 + osl: 1024 + max_num_tokens: 16384 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 9 + model: "Scout-FP8" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 1024 + osl: 1024 + max_num_tokens: 2176 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 10 + model: "Scout-FP8" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 8192 + osl: 1024 + max_num_tokens: 9334 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 11 + model: "Scout-FP8" + gpus: 4 + tp: 4 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 1024 + osl: 1024 + max_num_tokens: 2176 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 12 + model: "Scout-FP8" + gpus: 4 + tp: 4 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 8192 + osl: 1024 + max_num_tokens: 9334 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 13 + model: "Scout-FP4" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 1024 + osl: 1024 + max_num_tokens: 2176 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 14 + model: "Scout-FP4" + gpus: 1 + tp: 1 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 8192 + osl: 1024 + max_num_tokens: 9334 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 15 + model: "Scout-FP4" + gpus: 4 + tp: 4 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 1024 + osl: 1024 + max_num_tokens: 2176 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 16 + model: "Scout-FP4" + gpus: 4 + tp: 4 + ep: 1 + attn_backend: "TRTLLM" + moe_backend: "" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 8192 + osl: 1024 + max_num_tokens: 9334 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + - [64, 5] + - [512, 2] + + - id: 17 + model: "R1-FP8" + gpus: 8 + tp: 8 + ep: 8 + attn_backend: "TRTLLM" + moe_backend: "DEEPGEMM" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 1024 + isl: 1024 + osl: 1024 + max_num_tokens: 2176 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + + - id: 18 + model: "R1-FP8" + gpus: 8 + tp: 8 + ep: 8 + attn_backend: "TRTLLM" + moe_backend: "DEEPGEMM" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 1024 + isl: 8192 + osl: 1024 + max_num_tokens: 9344 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + + - id: 19 + model: "R1-FP8" + gpus: 8 + tp: 8 + ep: 8 + attn_backend: "TRTLLM" + moe_backend: "DEEPGEMM" + enable_attention_dp: true + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 1024 + osl: 1024 + max_num_tokens: 2176 + moe_max_num_tokens: 37376 + concurrency_iterations: + - [64, 5] + - [512, 2] + - [4096, 2] + + - id: 20 + model: "R1-FP8" + gpus: 8 + tp: 8 + ep: 8 + attn_backend: "TRTLLM" + moe_backend: "DEEPGEMM" + enable_attention_dp: true + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 8192 + osl: 1024 + max_num_tokens: 9344 + moe_max_num_tokens: 9344 + concurrency_iterations: + - [64, 5] + - [512, 2] + - [4096, 2] + + - id: 21 + model: "R1-FP4" + gpus: 8 + tp: 8 + ep: 8 + attn_backend: "TRTLLM" + moe_backend: "TRTLLM" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 1024 + isl: 1024 + osl: 1024 + max_num_tokens: 2176 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + + - id: 22 + model: "R1-FP4" + gpus: 8 + tp: 8 + ep: 8 + attn_backend: "TRTLLM" + moe_backend: "TRTLLM" + enable_attention_dp: false + free_gpu_mem_fraction: 0.8 + max_batch_size: 1024 + isl: 8192 + osl: 1024 + max_num_tokens: 9344 + moe_max_num_tokens: "" + concurrency_iterations: + - [1, 10] + - [8, 10] + + - id: 23 + model: "R1-FP4" + gpus: 8 + tp: 8 + ep: 8 + attn_backend: "TRTLLM" + moe_backend: "CUTLASS" + enable_attention_dp: true + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 1024 + osl: 1024 + max_num_tokens: 2176 + moe_max_num_tokens: 37376 + concurrency_iterations: + - [64, 5] + - [512, 2] + - [4096, 2] + + - id: 24 + model: "R1-FP4" + gpus: 8 + tp: 8 + ep: 8 + attn_backend: "TRTLLM" + moe_backend: "CUTLASS" + enable_attention_dp: true + free_gpu_mem_fraction: 0.8 + max_batch_size: 512 + isl: 8192 + osl: 1024 + max_num_tokens: 9344 + moe_max_num_tokens: 9344 + concurrency_iterations: + - [64, 5] + - [512, 2] + - [4096, 2] diff --git a/tests/scripts/perf-sanity/parse_benchmark_results.py b/tests/scripts/perf-sanity/parse_benchmark_results.py new file mode 100644 index 0000000000..ae3c87d3b4 --- /dev/null +++ b/tests/scripts/perf-sanity/parse_benchmark_results.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +import argparse +import re +import sys +from pathlib import Path + +import pandas as pd +import yaml + + +def extract_config_from_log_content(log_file): + """ + Extract configuration from log file content using "Completed benchmark with Configuration:" pattern + """ + try: + with open(log_file, 'r') as f: + for line in f: + if "Completed benchmark with Configuration:" in line: + # Extract values using regex patterns + model_label_match = re.search(r'model_label=([^,]+)', line) + gpus_match = re.search(r'GPUs=(\d+)', line) + tp_match = re.search(r'TP=(\d+)', line) + ep_match = re.search(r'EP=(\d+)', line) + attn_backend_match = re.search(r'attn_backend=([^,]+)', + line) + moe_backend_match = re.search(r'moe_backend=([^,]+)', line) + enable_attention_dp_match = re.search( + r'enable_attention_dp=([^,]+)', line) + free_gpu_mem_fraction_match = re.search( + r'free_gpu_mem_fraction=([^,]+)', line) + max_batch_size_match = re.search(r'max_batch_size=(\d+)', + line) + isl_match = re.search(r'ISL=(\d+)', line) + osl_match = re.search(r'OSL=(\d+)', line) + max_num_tokens_match = re.search(r'max_num_tokens=(\d+)', + line) + moe_max_num_tokens_match = re.search( + r'moe_max_num_tokens=([^,]+)', line) + concurrency_match = re.search(r'Concurrency=(\d+)', line) + + # Extract values, use empty string if not found + model_label = model_label_match.group( + 1) if model_label_match else "" + gpus = int(gpus_match.group(1)) if gpus_match else "" + tp = int(tp_match.group(1)) if tp_match else "" + ep = int(ep_match.group(1)) if ep_match else "" + attn_backend = attn_backend_match.group( + 1) if attn_backend_match else "" + moe_backend = moe_backend_match.group( + 1) if moe_backend_match else "" + enable_attention_dp = enable_attention_dp_match.group( + 1) if enable_attention_dp_match else "" + free_gpu_mem_fraction = float( + free_gpu_mem_fraction_match.group( + 1)) if free_gpu_mem_fraction_match else "" + max_batch_size = int(max_batch_size_match.group( + 1)) if max_batch_size_match else "" + isl = int(isl_match.group(1)) if isl_match else "" + osl = int(osl_match.group(1)) if osl_match else "" + max_num_tokens = int(max_num_tokens_match.group( + 1)) if max_num_tokens_match else "" + moe_max_num_tokens_str = moe_max_num_tokens_match.group( + 1) if moe_max_num_tokens_match else "" + concurrency = int( + concurrency_match.group(1)) if concurrency_match else "" + + # Handle moe_max_num_tokens (could be "N/A", empty, or a number) + moe_max_num_tokens = "" + if moe_max_num_tokens_str and moe_max_num_tokens_str != "N/A": + try: + moe_max_num_tokens = int(moe_max_num_tokens_str) + except ValueError: + moe_max_num_tokens = "" + elif not moe_max_num_tokens_str: + moe_max_num_tokens = "" + + # Handle enable_attention_dp (convert string to boolean) + enable_attention_dp_bool = "" + if enable_attention_dp: + enable_attention_dp_bool = enable_attention_dp.lower( + ) == "true" + + # Check if all required fields are present (not empty strings) + if (model_label and gpus != "" and tp != "" and ep != "" + and attn_backend and free_gpu_mem_fraction != "" + and max_batch_size != "" and isl != "" and osl != "" + and max_num_tokens != "" and concurrency != ""): + return { + 'model_name': model_label, + 'gpus': gpus, + 'tp': tp, + 'ep': ep, + 'attn_backend': attn_backend, + 'moe_backend': moe_backend, + 'enable_attention_dp': enable_attention_dp_bool, + 'free_gpu_mem_fraction': free_gpu_mem_fraction, + 'max_batch_size': max_batch_size, + 'isl': isl, + 'osl': osl, + 'max_num_tokens': max_num_tokens, + 'moe_max_num_tokens': moe_max_num_tokens, + 'concurrency': concurrency, + 'found_in_log': True + } + else: + print( + f"Warning: Incomplete configuration in {log_file} - missing required fields" + ) + return None + except Exception as e: + print(f"Warning: Could not read {log_file}: {e}") + + return None + + +def extract_metrics_from_log(log_file): + """ + Extract Total Token throughput and User throughput from log file + """ + total_throughput = "" + user_throughput = "" + + try: + with open(log_file, 'r') as f: + for line in f: + if "Total Token throughput (tok/s):" in line: + parts = line.strip().split() + if len(parts) >= 5: + total_throughput = parts[4] + elif "User throughput (tok/s):" in line: + parts = line.strip().split() + if len(parts) >= 4: + user_throughput = parts[3] + except Exception as e: + print(f"Warning: Could not read {log_file}: {e}") + + return total_throughput, user_throughput + + +def generate_all_test_cases(benchmark_config): + """ + Generate all test cases from benchmark_config.yaml including all concurrency iterations + """ + all_test_cases = [] + + for test_case in benchmark_config['test_cases']: + base_config = { + 'model_name': test_case['model'], + 'gpus': test_case['gpus'], + 'tp': test_case['tp'], + 'ep': test_case['ep'], + 'attn_backend': test_case['attn_backend'], + 'moe_backend': test_case['moe_backend'], + 'enable_attention_dp': test_case['enable_attention_dp'], + 'free_gpu_mem_fraction': test_case['free_gpu_mem_fraction'], + 'max_batch_size': test_case['max_batch_size'], + 'isl': test_case['isl'], + 'osl': test_case['osl'], + 'max_num_tokens': test_case['max_num_tokens'], + 'moe_max_num_tokens': test_case['moe_max_num_tokens'], + } + + # Generate a test case for each concurrency iteration + for concurrency, iterations in test_case['concurrency_iterations']: + test_case_config = base_config.copy() + test_case_config['concurrency'] = concurrency + test_case_config['iterations'] = iterations + test_case_config['TPS/System'] = "" + test_case_config['TPS/User'] = "" + all_test_cases.append(test_case_config) + + return all_test_cases + + +def match_log_to_test_case(log_config, test_case): + """ + Check if a log configuration matches a test case configuration + Returns True if all parameters match exactly + """ + if not log_config: + return False + + # Check if all key parameters match exactly + return (log_config['model_name'] == test_case['model_name'] + and log_config['gpus'] == test_case['gpus'] + and log_config['tp'] == test_case['tp'] + and log_config['ep'] == test_case['ep'] + and log_config['attn_backend'] == test_case['attn_backend'] + and log_config['moe_backend'] == test_case['moe_backend'] + and log_config['enable_attention_dp'] + == test_case['enable_attention_dp'] + and log_config['free_gpu_mem_fraction'] + == test_case['free_gpu_mem_fraction'] + and log_config['max_batch_size'] == test_case['max_batch_size'] + and log_config['isl'] == test_case['isl'] + and log_config['osl'] == test_case['osl'] + and log_config['max_num_tokens'] == test_case['max_num_tokens'] and + (log_config['moe_max_num_tokens'] == test_case['moe_max_num_tokens'] + or (not log_config['moe_max_num_tokens'] + and not test_case['moe_max_num_tokens'])) + and log_config['concurrency'] == test_case['concurrency']) + + +def create_test_case_row(test_case): + """ + Create a row for a test case with empty performance data + """ + return { + 'model_name': test_case['model_name'], + 'GPUs': test_case['gpus'], + 'TP': test_case['tp'], + 'EP': test_case['ep'], + 'attn_backend': test_case['attn_backend'], + 'moe_backend': test_case['moe_backend'], + 'enable_attention_dp': test_case['enable_attention_dp'], + 'free_gpu_mem_fraction': test_case['free_gpu_mem_fraction'], + 'max_batch_size': test_case['max_batch_size'], + 'ISL': test_case['isl'], + 'OSL': test_case['osl'], + 'max_num_tokens': test_case['max_num_tokens'], + 'moe_max_num_tokens': test_case['moe_max_num_tokens'], + 'Concurrency': test_case['concurrency'], + 'Iterations': test_case['iterations'], + 'TPS/System': test_case['TPS/System'], + 'TPS/User': test_case['TPS/User'], + } + + +def parse_benchmark_results(input_folder, output_csv, config_file): + """ + Parse benchmark results and generate CSV table + """ + input_folder = Path(input_folder) + config_file = Path(config_file) + + # Validate inputs + if not input_folder.exists(): + print(f"Error: Input folder '{input_folder}' does not exist") + return + + if not input_folder.is_dir(): + print(f"Error: '{input_folder}' is not a directory") + return + + if not config_file.exists(): + print(f"Error: Config file '{config_file}' does not exist") + return + + # Load benchmark configuration + try: + with open(config_file, 'r') as f: + benchmark_config = yaml.safe_load(f) + print(f"Loaded benchmark configuration from: {config_file}") + except Exception as e: + print(f"Error: Could not load {config_file}: {e}") + return + + # Generate all test cases from config + all_test_cases = generate_all_test_cases(benchmark_config) + print(f"Generated {len(all_test_cases)} test cases from configuration") + + # Find all serve.*.log files + log_files = list(input_folder.glob("serve.*.log")) + print(f"Found {len(log_files)} log files to process") + + # Process each log file + matched_count = 0 + for log_file in log_files: + print(f"Processing: {log_file.name}") + + # Extract configuration from log + log_config = extract_config_from_log_content(log_file) + if not log_config: + print(f" Skipped - could not parse configuration") + continue + + # Extract performance metrics + total_throughput, user_throughput = extract_metrics_from_log(log_file) + + # Find matching test case in table + matched = False + for test_case in all_test_cases: + if match_log_to_test_case(log_config, test_case): + # Update performance data + test_case['TPS/System'] = total_throughput + test_case['TPS/User'] = user_throughput + matched = True + matched_count += 1 + break + + if not matched: + print( + f" Skipped - no matching test case found for test case {test_case}" + ) + + print(f"Successfully matched {matched_count} log files to test cases") + + table_rows = [] + for test_case in all_test_cases: + row = create_test_case_row(test_case) + table_rows.append(row) + + # Add empty rows between different test configurations + final_table = [] + for i, row in enumerate(table_rows): + if i > 0: + prev_row = table_rows[i - 1] + # Check if any key parameters changed + if (row['model_name'] != prev_row['model_name'] + or row['TP'] != prev_row['TP'] + or row['EP'] != prev_row['EP'] + or row['moe_backend'] != prev_row['moe_backend'] + or row['ISL'] != prev_row['ISL'] + or row['OSL'] != prev_row['OSL']): + # Add empty row + empty_row = {key: '' for key in row.keys()} + final_table.append(empty_row) + + final_table.append(row) + + # Create DataFrame and save to CSV + df = pd.DataFrame(final_table) + + # Ensure output directory exists + output_path = Path(output_csv) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Save to CSV + df.to_csv(output_path, index=False) + + # Print summary + print(f"\nCSV table saved to: {output_path}") + print( + f"Total rows: {len(final_table)} (including {len(final_table) - len(table_rows)} empty separator rows)" + ) + + return df + + +def main(): + parser = argparse.ArgumentParser( + description= + "Script to parse benchmark metrics from a specified folder and generate CSV table", + epilog= + "Example: python parse_benchmark_results.py ./benchmark_logs results.csv ./benchmark_config.yaml" + ) + parser.add_argument( + "--input_folder", + help="Folder containing benchmark log files (serve.*.log)") + parser.add_argument("--output_csv", + help="Output CSV filename for the results table") + parser.add_argument("--config_file", + help="Path to benchmark_config.yaml file") + + args = parser.parse_args() + + # Validate inputs + input_folder_path = Path(args.input_folder) + config_file_path = Path(args.config_file) + + if not input_folder_path.exists(): + print(f"Error: Input folder '{args.input_folder}' not found.") + sys.exit(1) + if not input_folder_path.is_dir(): + print(f"Error: '{args.input_folder}' is not a directory.") + sys.exit(1) + if not config_file_path.exists(): + print(f"Error: Config file '{args.config_file}' not found.") + sys.exit(1) + + print(f"Using input folder: {input_folder_path}") + print(f"Using config file: {config_file_path}") + print(f"Output will be saved to: {args.output_csv}") + print() + + parse_benchmark_results(args.input_folder, args.output_csv, + args.config_file) + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/perf-sanity/run_benchmark_serve.py b/tests/scripts/perf-sanity/run_benchmark_serve.py new file mode 100644 index 0000000000..2d4928ae32 --- /dev/null +++ b/tests/scripts/perf-sanity/run_benchmark_serve.py @@ -0,0 +1,779 @@ +#!/usr/bin/env python3 +import argparse +import os +import subprocess +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Set + +import requests +import yaml + + +class BenchmarkRunner: + + def __init__(self, + output_folder: str, + config_file: str, + skip_pattern: str = None, + select_pattern: str = None): + self.output_folder = Path(output_folder) + self.config_file = Path(config_file) + + # Treat empty or "default" values as None (default behavior) + self.skip_pattern = None if not skip_pattern or skip_pattern.lower( + ) == "default" else skip_pattern + self.select_pattern = None if not select_pattern or select_pattern.lower( + ) == "default" else select_pattern + + self.skip_test_cases: Set[int] = set() + self.skip_concurrencies: Dict[int, Set[int]] = {} + self.select_test_cases: Set[int] = set() + self.select_concurrencies: Dict[int, Set[int]] = {} + + if self.skip_pattern: + self.parse_skip_pattern(self.skip_pattern) + + if self.select_pattern: + self.parse_select_pattern(self.select_pattern) + + # Execution plan: {test_case_id: [concurrency_indices]} + self.execution_plan: Dict[int, List[int]] = {} + + # Model path mapping + self.model_paths = { + "70B-FP4": + "/home/scratch.trt_llm_data/llm-models/llama-3.3-models/Llama-3.3-70B-Instruct-FP4", + "70B-FP8": + "/home/scratch.trt_llm_data/llm-models/llama-3.3-models/Llama-3.3-70B-Instruct-FP8", + "Scout-FP4": + "/home/scratch.trt_llm_data/llm-models/llama4-models/Llama-4-Scout-17B-16E-Instruct-FP4", + "Scout-FP8": + "/home/scratch.trt_llm_data/llm-models/llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8", + "R1-FP8": + "/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1/", + "R1-FP4": + "/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-0528-FP4" + } + + # Set environment variables + os.environ['TQDM_MININTERVAL'] = '1000' + os.environ['PRINT_ITER_LOG'] = 'false' + + # Capture system information + self.node_name = self.get_node_name() + self.gpu_info = self.get_gpu_info() + + # Change to output directory + os.chdir(self.output_folder) + + def get_node_name(self) -> str: + """Get the current node name""" + try: + result = subprocess.run("hostname", + shell=True, + capture_output=True, + text=True, + check=True) + return result.stdout.strip() + except (subprocess.CalledProcessError, FileNotFoundError): + return "unknown" + + def get_gpu_info(self) -> str: + """Get GPU information from nvidia-smi""" + try: + result = subprocess.run("nvidia-smi", + shell=True, + capture_output=True, + text=True, + check=True) + return result.stdout + except subprocess.CalledProcessError as e: + return f"nvidia-smi failed with error code {e.returncode}\nError output: {e.stderr}" + except FileNotFoundError: + return "nvidia-smi not found" + + def parse_skip_pattern(self, skip_pattern: str) -> None: + """Parse skip pattern like '2,4-1' to determine what to skip""" + if not skip_pattern: + return + + parts = skip_pattern.split(',') + for part in parts: + part = part.strip() + if not part: # Skip empty parts + continue + + if '-' in part: + # Format: "test_case-concurrency_index" (1-based) + try: + test_case_str, concurrency_str = part.split('-') + test_case_id = int(test_case_str) + concurrency_index = int( + concurrency_str) - 1 # Convert to 0-based + + if test_case_id not in self.skip_concurrencies: + self.skip_concurrencies[test_case_id] = set() + self.skip_concurrencies[test_case_id].add(concurrency_index) + except ValueError: + raise ValueError( + f"Invalid skip pattern '{part}'. Expected format: 'test_case-concurrency_index' (e.g., '2-1')" + ) + else: + # Format: "test_case" - skip entire test case + try: + test_case_id = int(part) + self.skip_test_cases.add(test_case_id) + except ValueError: + raise ValueError( + f"Invalid test case ID '{part}' in skip pattern. Must be a valid integer." + ) + + print(f"Skipping test cases: {sorted(self.skip_test_cases)}") + print(f"Skipping concurrencies: {self.skip_concurrencies}") + + def parse_select_pattern(self, select_pattern: str) -> None: + """Parse select pattern like '1,3,5' or '1-1,2-3' to determine which test cases/concurrencies to run""" + if not select_pattern: + return + + self.select_concurrencies: Dict[int, Set[int]] = {} + + parts = select_pattern.split(',') + for part in parts: + part = part.strip() + if not part: # Skip empty parts + continue + + if '-' in part: + # Format: "test_case-concurrency_index" (1-based) + try: + test_case_str, concurrency_str = part.split('-') + test_case_id = int(test_case_str) + concurrency_index = int( + concurrency_str) - 1 # Convert to 0-based + + if test_case_id not in self.select_concurrencies: + self.select_concurrencies[test_case_id] = set() + self.select_concurrencies[test_case_id].add( + concurrency_index) + except ValueError: + raise ValueError( + f"Invalid select pattern '{part}'. Expected format: 'test_case-concurrency_index' (e.g., '2-1')" + ) + else: + # Format: "test_case" - select entire test case + try: + test_case_id = int(part) + self.select_test_cases.add(test_case_id) + except ValueError: + raise ValueError( + f"Invalid test case ID '{part}' in select pattern. Must be a valid integer." + ) + + print(f"Selected test cases: {sorted(self.select_test_cases)}") + print(f"Selected concurrencies: {self.select_concurrencies}") + + def build_execution_plan(self, test_cases: List[Dict[str, Any]]) -> None: + """Build execution plan by analyzing config file, skip_pattern, and select_pattern""" + self.execution_plan.clear() + + # Step 1: Initialize execution plan based on select_pattern + if not self.select_pattern: + # If select_pattern is empty or default, include all test cases with all concurrencies + for test_case in test_cases: + test_case_id = test_case['id'] + all_concurrencies = list( + range(len(test_case['concurrency_iterations']))) + self.execution_plan[test_case_id] = all_concurrencies + else: + # If select_pattern is specified, only include selected test cases and concurrencies + for test_case in test_cases: + test_case_id = test_case['id'] + + # Check if this test case is selected + if test_case_id in self.select_test_cases: + # Test case is selected - include all concurrencies + all_concurrencies = list( + range(len(test_case['concurrency_iterations']))) + self.execution_plan[test_case_id] = all_concurrencies + elif test_case_id in self.select_concurrencies: + # Specific concurrencies are selected for this test case + selected_concurrencies = list( + self.select_concurrencies[test_case_id]) + # Validate that selected concurrencies exist in config + max_concurrency_index = len( + test_case['concurrency_iterations']) - 1 + valid_concurrencies = [ + c for c in selected_concurrencies + if 0 <= c <= max_concurrency_index + ] + if valid_concurrencies: + self.execution_plan[test_case_id] = valid_concurrencies + + # Step 2: Apply skip_pattern to remove test cases and concurrencies + # Remove entire test cases that are in skip_test_cases + for test_case_id in self.skip_test_cases: + if test_case_id in self.execution_plan: + del self.execution_plan[test_case_id] + + # Remove specific concurrencies that are in skip_concurrencies + for test_case_id, skip_concurrency_indices in self.skip_concurrencies.items( + ): + if test_case_id in self.execution_plan: + # Remove skipped concurrencies from the list + remaining_concurrencies = [ + c for c in self.execution_plan[test_case_id] + if c not in skip_concurrency_indices + ] + if remaining_concurrencies: + self.execution_plan[test_case_id] = remaining_concurrencies + else: + # If no concurrencies remain, remove the entire test case + del self.execution_plan[test_case_id] + + # Step 3: Clean up - remove test cases with empty concurrency lists + # (This should not happen with the above logic, but just to be safe) + test_cases_to_remove = [] + for test_case_id, concurrencies in self.execution_plan.items(): + if not concurrencies: + test_cases_to_remove.append(test_case_id) + + for test_case_id in test_cases_to_remove: + del self.execution_plan[test_case_id] + + def print_execution_plan(self, test_cases: List[Dict[str, Any]]) -> None: + """Print which test cases and concurrencies will be executed""" + print("\n" + "=" * 80) + print("EXECUTION PLAN") + print("=" * 80) + + total_test_cases = 0 + total_concurrencies = 0 + + for test_case in test_cases: + test_case_id = test_case['id'] + model_label = test_case['model'] + + # Check if this test case is in execution plan + if test_case_id not in self.execution_plan: + print(f"Test Case {test_case_id}: {model_label} - SKIPPED") + continue + + total_test_cases += 1 + print(f"\nTest Case {test_case_id}: {model_label}") + print( + f" Config: GPUs={test_case['gpus']}, TP={test_case['tp']}, EP={test_case['ep']}, attn_backend={test_case['attn_backend']}, moe_backend={test_case['moe_backend']}" + ) + + # Get concurrencies from execution plan + concurrencies_to_run = [] + for concurrency_index in self.execution_plan[test_case_id]: + concurrency, iteration = test_case['concurrency_iterations'][ + concurrency_index] + concurrencies_to_run.append( + (concurrency_index + 1, concurrency, + iteration)) # +1 for 1-based display + total_concurrencies += 1 + + print( + f" Concurrencies to run ({len(concurrencies_to_run)}/{len(test_case['concurrency_iterations'])}):" + ) + for concurrency_num, concurrency, iteration in concurrencies_to_run: + print( + f" {concurrency_num}. Concurrency={concurrency}, Iteration={iteration}" + ) + + print("\n" + "=" * 80) + print( + f"SUMMARY: {total_test_cases} test cases, {total_concurrencies} concurrencies will be executed" + ) + print("=" * 80 + "\n") + + def generate_extra_llm_api_config(self, test_case: Dict[str, Any]) -> str: + """Generate extra-llm-api-config.yml content""" + config_lines = [ + "print_iter_log: true", + f"enable_attention_dp: {str(test_case['enable_attention_dp']).lower()}", + "disable_overlap_scheduler: false", + "stream_interval: 10", + f"attn_backend: {test_case['attn_backend']}", + "cuda_graph_config:", + " enable_padding: true", + f" max_batch_size: {test_case['max_batch_size']}", + "kv_cache_config:", + " dtype: fp8", + f" free_gpu_memory_fraction: {test_case['free_gpu_mem_fraction']}", + " enable_block_reuse: false", + ] + + # Add moe_config if moe_backend is specified + if test_case['moe_backend']: + config_lines.append("moe_config:") + config_lines.append(f" backend: {test_case['moe_backend']}") + + if test_case['moe_max_num_tokens']: + config_lines.append( + f" max_num_tokens: {test_case['moe_max_num_tokens']}") + + return "\n".join(config_lines) + + def wait_for_server(self, + server_pid: int, + server_log_filename: str, + max_attempts: int = 360) -> bool: + """Wait for server to be ready""" + print("Waiting for trtllm-serve to be ready...") + + for attempt in range(1, max_attempts + 1): + # Check if server is still running + try: + os.kill(server_pid, 0) # Check if process exists + except OSError: + print("Error: Server process has died") + return False + + # Check server log for runtime errors + if self.check_for_runtime_error(server_log_filename): + print( + f"RuntimeError detected in server log: {server_log_filename}" + ) + print("Killing server process due to runtime error") + try: + subprocess.run(f"kill -9 {server_pid}", + shell=True, + check=False) + subprocess.run(f"wait {server_pid} 2>/dev/null || true", + shell=True, + check=False) + except Exception as e: + print(f"Warning: Error killing server process: {e}") + return False + + # Try to connect to server + try: + response = requests.get("http://localhost:8000/v1/models", + timeout=5) + if response.status_code == 200: + print( + f"Server is ready! HTTP status: {response.status_code}") + return True + except requests.RequestException: + pass + + print( + f"Attempt {attempt}/{max_attempts}: Server not ready yet, waiting..." + ) + time.sleep(10) + + print( + f"Error: Server did not become ready after {max_attempts} attempts") + return False + + def check_for_runtime_error(self, log_file_path: str) -> bool: + """Check if RuntimeError exists in log file""" + try: + if os.path.exists(log_file_path): + with open(log_file_path, 'r') as f: + content = f.read() + if "RuntimeError" in content or "runtime error" in content or "illegal memory access" in content or "terminate called" in content: + return True + except Exception as e: + print(f"Warning: Could not read log file {log_file_path}: {e}") + return False + + def run_benchmark(self, test_case: Dict[str, Any], concurrency: int, + iteration: int, model_path: str, + server_log_filename: str) -> bool: + """Run a single benchmark with monitoring. Returns True if successful, False if should skip test case""" + num_prompts = concurrency * iteration + + print( + f'Running benchmark with concurrency: {concurrency}, iteration: {iteration}, num-prompts: {num_prompts}' + ) + + # Build benchmark command + benchmark_cmd = [ + "python", "-m", "tensorrt_llm.serve.scripts.benchmark_serving", + "--model", model_path, "--dataset-name", "random", "--random-ids", + "--num-prompts", + str(num_prompts), "--random-input-len", + str(test_case['isl']), "--random-output-len", + str(test_case['osl']), "--random-range-ratio", "0.0", + "--ignore-eos", "--percentile-metrics", "ttft,tpot,itl,e2el", + "--max-concurrency", + str(concurrency) + ] + + print(f'Running benchmark with command:') + print(' '.join(benchmark_cmd)) + print() + + # Prepare log filename + benchmark_log_filename = ( + f"serve.{test_case['model']}.tp{test_case['tp']}.ep{test_case['ep']}." + f"attn{test_case['attn_backend']}.moe{test_case['moe_backend']}." + f"gpu{test_case['free_gpu_mem_fraction']}.batch{test_case['max_batch_size']}." + f"isl{test_case['isl']}.osl{test_case['osl']}." + f"tokens{test_case['max_num_tokens']}.moetokens{test_case['moe_max_num_tokens']}." + f"concurrency{concurrency}.iter{iteration}.log") + + try: + with open(benchmark_log_filename, 'w') as f: + f.write(f"GPU Info: {self.gpu_info}\n") + + # Start benchmark as subprocess + with open(benchmark_log_filename, 'a') as log_file: + benchmark_process = subprocess.Popen(benchmark_cmd, + stdout=log_file, + stderr=subprocess.STDOUT) + + # Monitor logs every 60 seconds with timeout + print( + f"Starting log monitoring for benchmark process (PID: {benchmark_process.pid})" + ) + + start_time = time.time() + timeout_seconds = 3600 # 1 hour timeout + + while benchmark_process.poll() is None: # Process is still running + time.sleep(60) # Wait 60 seconds + + # Check if benchmark has been running for more than 1 hour + elapsed_time = time.time() - start_time + if elapsed_time > timeout_seconds: + print( + f"Benchmark timeout after {elapsed_time:.0f} seconds (>{timeout_seconds} seconds)" + ) + print("Killing benchmark process due to timeout") + try: + subprocess.run(f"kill -9 {benchmark_process.pid}", + shell=True, + check=False) + benchmark_process.wait(timeout=10) + except Exception as e: + print(f"Warning: Error killing benchmark process: {e}") + return False # Signal to skip test case + + print( + f"Checking logs for RuntimeError... (benchmark PID: {benchmark_process.pid}, elapsed: {elapsed_time:.0f}s)" + ) + + # Check server log for RuntimeError + if self.check_for_runtime_error(server_log_filename): + print( + f"RuntimeError found in server log: {server_log_filename}" + ) + print( + "Killing benchmark process and skipping this test case") + try: + subprocess.run(f"kill -9 {benchmark_process.pid}", + shell=True, + check=False) + benchmark_process.wait(timeout=10) + except Exception as e: + print(f"Warning: Error killing benchmark process: {e}") + return False # Signal to skip test case + + # Check benchmark log for RuntimeError + if self.check_for_runtime_error(benchmark_log_filename): + print( + f"RuntimeError found in benchmark log: {benchmark_log_filename}" + ) + print( + "Killing benchmark process and skipping this test case") + try: + subprocess.run(f"kill -9 {benchmark_process.pid}", + shell=True, + check=False) + benchmark_process.wait(timeout=10) + except Exception as e: + print(f"Warning: Error killing benchmark process: {e}") + return False # Signal to skip test case + + # Process completed, check final return code + return_code = benchmark_process.returncode + if return_code != 0: + print( + f"Benchmark process completed with error code: {return_code}" + ) + + # Read and display error output + try: + with open(benchmark_log_filename, 'r') as f: + error_content = f.read() + print( + f"Benchmark error output:\n{error_content[-1000:]}" + ) # Last 1000 chars + except Exception as e: + print(f"Could not read benchmark log: {e}") + + print( + f"Skipping this concurrency level and continuing with next one..." + ) + print("-----------------------------------------") + return True # Continue with next concurrency, don't skip test case + + # Success case + print( + f"Benchmark completed successfully (PID: {benchmark_process.pid})" + ) + + # Add configuration summary to log file + config_summary = ( + f"Completed benchmark with Configuration: " + f"model_label={test_case['model']}, GPUs={test_case['gpus']}, " + f"TP={test_case['tp']}, EP={test_case['ep']}, " + f"attn_backend={test_case['attn_backend']}, " + f"moe_backend={test_case['moe_backend']}, " + f"enable_attention_dp={test_case['enable_attention_dp']}, " + f"free_gpu_mem_fraction={test_case['free_gpu_mem_fraction']}, " + f"max_batch_size={test_case['max_batch_size']}, " + f"ISL={test_case['isl']}, OSL={test_case['osl']}, " + f"max_num_tokens={test_case['max_num_tokens']}, " + f"moe_max_num_tokens={test_case['moe_max_num_tokens']}, " + f"Concurrency={concurrency}") + with open(benchmark_log_filename, 'a') as f: + f.write(f"\n{config_summary}\n") + + print("-----------------------------------------") + return True # Continue with next concurrency + + except Exception as e: + print( + f"Error running benchmark with concurrency {concurrency}: {e}") + print( + f"Skipping this concurrency level and continuing with next one..." + ) + print("-----------------------------------------") + return True # Continue with next concurrency, don't skip test case + + def run_test_case(self, test_case: Dict[str, Any]) -> None: + """Run a test case using the execution plan""" + model_label = test_case['model'] + test_case_id = test_case['id'] + + # Get model path + model_path = self.model_paths.get(model_label) + if not model_path: + print(f"Error: No model path found for {model_label}") + return + + # Use local path if it exists, otherwise use model name + if os.path.exists(model_path): + MODEL = model_path + else: + MODEL = model_label + + # Generate extra-llm-api-config.yml + config_content = self.generate_extra_llm_api_config(test_case) + config_path = "/tmp/extra-llm-api-config.yml" + + with open(config_path, 'w') as f: + f.write(config_content) + + print("extra-llm-api-config.yml:") + print(config_content) + + # Build trtllm-serve command + serve_cmd = [ + "trtllm-serve", MODEL, "--backend", "pytorch", "--tp_size", + str(test_case['tp']), "--ep_size", + str(test_case['ep']), "--max_batch_size", + str(test_case['max_batch_size']), "--max_num_tokens", + str(test_case['max_num_tokens']), + "--kv_cache_free_gpu_memory_fraction", + str(test_case['free_gpu_mem_fraction']), "--extra_llm_api_options", + config_path + ] + + print("Starting trtllm-serve with command:") + print(' '.join(serve_cmd)) + print() + + # Start server + server_log_filename = ( + f"trtllm-serve.{model_label}.tp{test_case['tp']}.ep{test_case['ep']}." + f"attn{test_case['attn_backend']}.moe{test_case['moe_backend']}." + f"gpu{test_case['free_gpu_mem_fraction']}.batch{test_case['max_batch_size']}." + f"isl{test_case['isl']}.osl{test_case['osl']}." + f"tokens{test_case['max_num_tokens']}.moetokens{test_case['moe_max_num_tokens']}.log" + ) + + try: + with open(server_log_filename, 'w') as log_file: + log_file.write(f"extra-llm-api-config.yml:\n") + log_file.write(config_content) + log_file.write("\n") + + with open(server_log_filename, 'a') as log_file: + server_process = subprocess.Popen(serve_cmd, + stdout=log_file, + stderr=subprocess.STDOUT) + + # Wait for server to be ready + if not self.wait_for_server(server_process.pid, + server_log_filename): + print( + "Failed to start server, killing process and skipping this test case" + ) + try: + subprocess.run(f"kill -9 {server_process.pid}", + shell=True, + check=False) + subprocess.run( + f"wait {server_process.pid} 2>/dev/null || true", + shell=True, + check=False) + except Exception as e: + print(f"Warning: Error during server cleanup: {e}") + return + + # Run benchmarks based on execution plan + for concurrency_index in self.execution_plan[test_case_id]: + concurrency, iteration = test_case['concurrency_iterations'][ + concurrency_index] + should_continue = self.run_benchmark(test_case, concurrency, + iteration, MODEL, + server_log_filename) + + # If run_benchmark returns False, skip the entire test case + if not should_continue: + print( + f"RuntimeError detected - skipping remaining concurrencies for test case {test_case_id}" + ) + break + + finally: + # Cleanup: Kill server process using shell commands like in the original bash script + print(f"Stopping server for {model_label}") + try: + # Use shell commands for more reliable process killing + subprocess.run(f"kill -9 {server_process.pid}", + shell=True, + check=False) + subprocess.run(f"wait {server_process.pid} 2>/dev/null || true", + shell=True, + check=False) + except Exception as e: + print(f"Warning: Error during server cleanup: {e}") + + time.sleep(5) # Give it time to clean up resources + print(f"Benchmark completed for {model_label}") + print() + + def run_benchmarks(self) -> None: + """Main function to run all benchmarks from config file""" + script_start_time = time.time() + + print(f"Using config file: {self.config_file}") + if self.select_pattern: + print(f"Select pattern: {self.select_pattern}") + else: + print("Select pattern: default (all test cases)") + if self.skip_pattern: + print(f"Skip pattern: {self.skip_pattern}") + else: + print("Skip pattern: default (no skipping)") + + # Load configuration + with open(self.config_file, 'r') as f: + config = yaml.safe_load(f) + + test_cases = config['test_cases'] + + # Build execution plan + self.build_execution_plan(test_cases) + + # Print execution plan before starting benchmarks + self.print_execution_plan(test_cases) + + # Run each test case based on execution plan + for i, test_case in enumerate(test_cases, 1): + test_case_id = test_case['id'] + + if test_case_id not in self.execution_plan: + print("=" * 57) + print( + f"Test case {i}/{len(test_cases)} (ID: {test_case_id}): {test_case['model']} - SKIPPED" + ) + print("=" * 57) + continue + + print("=" * 57) + print( + f"Test case {i}/{len(test_cases)} (ID: {test_case_id}): {test_case['model']}" + ) + print( + f"Config: GPUs={test_case['gpus']}, TP={test_case['tp']}, EP={test_case['ep']}, attn_backend={test_case['attn_backend']}, moe_backend={test_case['moe_backend']}" + ) + print("=" * 57) + + self.run_test_case(test_case) + + # Calculate and display total script runtime + script_total_time = time.time() - script_start_time + hours = int(script_total_time // 3600) + minutes = int((script_total_time % 3600) // 60) + seconds = int(script_total_time % 60) + + print("=" * 80) + print("SCRIPT COMPLETION SUMMARY") + print("=" * 80) + print( + f"Total script runtime: {hours:02d}:{minutes:02d}:{seconds:02d} (HH:MM:SS)" + ) + print(f"Total runtime in seconds: {script_total_time:.2f}") + print("=" * 80) + print("All benchmarks completed!") + + +def main(): + parser = argparse.ArgumentParser( + description='Run benchmarks from YAML configuration file') + parser.add_argument('--output_folder', + required=True, + help='Output folder for benchmark results') + parser.add_argument('--config_file', + required=True, + help='Path to YAML configuration file') + parser.add_argument( + '--skip', + help= + 'Skip pattern: "2,4-1" means skip test case 2 and test case 4\'s 1st concurrency' + ) + parser.add_argument( + '--select', + help= + 'Select pattern: "1,3,5" means only run test cases 1, 3, and 5; "1-1,2-3" means only run test case 1\'s 1st concurrency and test case 2\'s 3rd concurrency' + ) + + args = parser.parse_args() + + try: + subprocess.run(f'echo "TRT-LLM GIT COMMIT": $TRT_LLM_GIT_COMMIT', + shell=True, + check=True) + except subprocess.CalledProcessError: + print("Warning: Could not echo TRT-LLM GIT COMMIT") + + if not os.path.exists(args.config_file): + print(f"Error: Config file '{args.config_file}' does not exist") + sys.exit(1) + + if not os.path.exists(args.output_folder): + print(f"Error: Output folder '{args.output_folder}' does not exist") + sys.exit(1) + + try: + runner = BenchmarkRunner(args.output_folder, args.config_file, + args.skip, args.select) + runner.run_benchmarks() + except Exception as e: + print(f"Error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index ef3bf35a43..283a3eb8f0 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -11,7 +11,7 @@ from torch.fx import GraphModule from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory -from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ShardingTransformInfo +from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ShardingTransformInfo class FakeFactory(ModelFactory): diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py index ab135aa28a..f47e38b994 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py @@ -6,16 +6,12 @@ import pytest import torch import torch.nn as nn from _dist_test_utils import get_device_counts -from _graph_test_helpers import run_sharding_pattern_detection_test, run_test +from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ( - BMMShardingInfo, - ShardingConfig, - detect_dp_bmm_shard, - sharding_transform_executor, -) +from tensorrt_llm._torch.auto_deploy.transform.library.sharding import BMMShardingInfo +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -64,22 +60,29 @@ def _run_job( num_experts = num_experts_multiplier * world_size model = BMM(num_experts, num_features).to(device="cuda", dtype=torch.float16) x = torch.randn(batch_size * num_experts, num_features, device="cuda", dtype=torch.float16) + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "detect_dp_bmm_shard": { + "stage": "sharding", + }, + "sharding_transform_executor": { + "stage": "sharding", + }, + }, + )(None, gm) def _get_expected_num_params(num_p_og: int) -> int: num_params = num_p_og // world_size return num_params - def transform_func(gm) -> None: - sharding_config = ShardingConfig() - detect_dp_bmm_shard(gm, rank, world_size, sharding_config) - sharding_transform_executor(gm, sharding_config) - # now run the test op_expected = getattr(torch.ops.auto_deploy, "torch_dist_all_gather") - run_test( + run_test_transformed_gm( model, x, - transform=transform_func, + gm_transformed, check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes) == (world_size > 1), _get_expected_num_params=_get_expected_num_params, @@ -118,9 +121,18 @@ def _run_pattern_detection_job( ) # get detected transformations - sharding_config = ShardingConfig() - detect_dp_bmm_shard(gm, rank, world_size, sharding_config) - detected_transformations = sharding_config.bmm_transforms + optimizer = InferenceOptimizer( + None, + { + "detect_dp_bmm_shard": { + "stage": "sharding", + }, + }, + ) + optimizer.shared_config.local_rank = rank + optimizer.shared_config.world_size = world_size + _ = optimizer(None, gm) + detected_transformations = optimizer.shared_config.sharding_config.bmm_transforms # Run pattern detection test run_sharding_pattern_detection_test(detected_transformations, expected_transformations) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 19cce48329..8a95771a3a 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -5,17 +5,13 @@ from functools import partial import pytest import torch from _dist_test_utils import get_device_counts -from _graph_test_helpers import run_sharding_pattern_detection_test, run_test +from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm from _model_test_utils import MoEOpModel import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ( - EPShardingInfo, - ShardingConfig, - detect_ep_shard, - sharding_transform_executor, -) +from tensorrt_llm._torch.auto_deploy.transform.library.sharding import EPShardingInfo +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -39,17 +35,25 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None: expected_expert = num_experts_per_rank * hidden_size * intermediate_size * 3 return n_gate + expected_expert - def transform_func(gm) -> None: - sharding_config = ShardingConfig() - detect_ep_shard(gm, rank, world_size, sharding_config) - sharding_transform_executor(gm, sharding_config) + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "detect_ep_shard": { + "stage": "sharding", + }, + "sharding_transform_executor": { + "stage": "sharding", + }, + }, + )(None, gm) op_expected = torch.ops.auto_deploy.torch_dist_all_reduce - run_test( + run_test_transformed_gm( model, x, - transform=transform_func, + gm_transformed, check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes) == (world_size > 1), _get_expected_num_params=partial(_get_expected_num_params, rank, world_size), @@ -89,9 +93,18 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> ) # get detected transformations - sharding_config = ShardingConfig() - detect_ep_shard(gm, rank, world_size, sharding_config) - detected_transformations = sharding_config.ep_transforms + optimizer = InferenceOptimizer( + None, + { + "detect_ep_shard": { + "stage": "sharding", + }, + }, + ) + optimizer.shared_config.local_rank = rank + optimizer.shared_config.world_size = world_size + _ = optimizer(None, gm) + detected_transformations = optimizer.shared_config.sharding_config.ep_transforms # Run pattern detection test run_sharding_pattern_detection_test(detected_transformations, expected_transformations) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 9e33bef4a9..016dc65906 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -8,17 +8,15 @@ import torch import torch.nn as nn import torch.nn.functional as F from _dist_test_utils import get_device_counts -from _graph_test_helpers import run_sharding_pattern_detection_test, run_test +from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.transformations.library import ( - ShardingConfig, +from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ( SplitDimension, TPShardingInfo, - detect_column_row_shard, - sharding_transform_executor, ) +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op @@ -146,10 +144,18 @@ def _run_job( # now run the test op_expected = getattr(torch.ops.auto_deploy, dist_op_expected) - def transform_func(gm) -> None: - sharding_config = ShardingConfig() - detect_column_row_shard(gm, rank, world_size, sharding_config) - sharding_transform_executor(gm, sharding_config) + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "detect_column_row_shard": { + "stage": "sharding", + }, + "sharding_transform_executor": { + "stage": "sharding", + }, + }, + )(None, gm) def combined_graph_check(gm) -> bool: # Check for expected distributed operations @@ -160,10 +166,10 @@ def _run_job( weight_sizes_valid = verify_local_weight_sizes(gm) return has_expected_dist_ops and weight_sizes_valid - run_test( + run_test_transformed_gm( model, x, - transform=transform_func, + gm_transformed, check_transformed_graph=combined_graph_check, _get_expected_num_params=_get_expected_num_params, ) @@ -262,9 +268,18 @@ def _run_pattern_detection_job( ) # get detected transformations - sharding_config = ShardingConfig() - detect_column_row_shard(gm, rank, world_size, sharding_config) - detected_transformations = sharding_config.tp_transforms + optimizer = InferenceOptimizer( + None, + { + "detect_column_row_shard": { + "stage": "sharding", + }, + }, + ) + optimizer.shared_config.local_rank = rank + optimizer.shared_config.world_size = world_size + _ = optimizer(None, gm) + detected_transformations = optimizer.shared_config.sharding_config.tp_transforms # Run pattern detection test run_sharding_pattern_detection_test(detected_transformations, expected_transformations) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py index f25079e04b..9b2c813e54 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py @@ -1,3 +1,5 @@ +import json +import re import subprocess import tempfile from pathlib import Path @@ -5,11 +7,225 @@ from pathlib import Path import pytest import yaml from _model_test_utils import _hf_model_dir_or_hub_id -from click.testing import CliRunner from utils.cpp_paths import llm_root # noqa: F401 from utils.llm_data import llm_models_root -from tensorrt_llm.commands.bench import main + +def parse_kv_cache_metrics(log_output: str, free_mem_ratio: float = 0.8): + """Parse KV cache metrics from the benchmark log output.""" + metrics = {} + + # Simple patterns based on actual log format + patterns = { + "current_cache_size": r"Current cache size:\s*(\d+)", + "free_mem_pre_mb": r"Free memory before forward pass \(MB\):\s*(\d+)", + "free_mem_post_mb": r"Free memory after forward pass \(MB\):\s*(\d+)", + } + + # Extract metrics using simple regex patterns + for metric_name, pattern in patterns.items(): + match = re.search(pattern, log_output, re.IGNORECASE) + if match: + value = int(match.group(1)) + metrics[metric_name] = value + print(f" ✅ Found {metric_name}: {value}") + else: + print(f" ❌ Could not find {metric_name}") + + # Calculate new_cache_size using the same formula as in resize_kv_cache + # new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size + if "free_mem_post_mb" in metrics and "current_cache_size" in metrics: + metrics["new_cache_size"] = int( + metrics["free_mem_post_mb"] * 1024 * 1024 * free_mem_ratio + + metrics["current_cache_size"] + ) + print( + f" ✅ Calculated new_cache_size: {metrics['new_cache_size']} (using free_mem_ratio={free_mem_ratio})" + ) + else: + print(" ❌ Cannot calculate new_cache_size - missing required metrics") + + return metrics + + +def run_benchmark( + model_name: str, + dataset_path: str, + temp_dir: str, + backend: str = "_autodeploy", + report_json_path: str = None, + max_batch_size: int = 32, + num_hidden_layers: int = 2, + free_mem_ratio: float = 0.1, +): + """Run benchmark and capture KV cache metrics from log output.""" + + # Read the test config to get free_mem_ratio + config_path = f"{temp_dir}/extra_llm_api_options.yaml" + + # Build the command to run the benchmark + cmd = [ + "python", + "-m", + "tensorrt_llm.commands.bench", + "--model", + model_name, + "throughput", + "--backend", + backend, + "--dataset", + str(dataset_path), + "--max_batch_size", + str(max_batch_size), + ] + + # Add report_json argument if path is provided + if report_json_path: + cmd.extend(["--report_json", report_json_path]) + + if backend == "_autodeploy": + # Add extra_llm_api_options only for autodeploy backend + cmd.extend(["--extra_llm_api_options", config_path]) + + # Run benchmark as subprocess to capture ALL output + import os + + env = os.environ.copy() + if backend == "pytorch": + env["TLLM_OVERRIDE_LAYER_NUM"] = str(num_hidden_layers) + print(f"📋 Using TLLM_OVERRIDE_LAYER_NUM from env: {env['TLLM_OVERRIDE_LAYER_NUM']}") + cmd.extend(["--kv_cache_free_gpu_mem_fraction", str(free_mem_ratio)]) + print(f"🚀 Running benchmark command ({backend} backend): {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True, env=env, timeout=600) + + # Check if the command succeeded + assert result.returncode == 0, ( + f"Benchmark failed with return code {result.returncode}:\n" + f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}" + ) + + # Combine stdout and stderr for parsing + full_log_output = f"{result.stdout}\n{result.stderr}" + + # Parse KV cache metrics from the combined log output (only for autodeploy backend) + kv_cache_metrics = {} + if backend == "_autodeploy": + kv_cache_metrics = parse_kv_cache_metrics(full_log_output, free_mem_ratio) + print("📊 KV Cache Metrics parsed from logs:") + if kv_cache_metrics: + for key, value in kv_cache_metrics.items(): + if "mb" in key.lower(): + print(f" {key}: {value}MB") + else: + print(f" {key}: {value} bytes") + else: + print(" ⚠️ No KV cache metrics were parsed successfully") + else: + print(f"📊 KV Cache Metrics: Skipped for {backend} backend") + + # Return parsed JSON report with KV cache metrics if requested + if report_json_path and Path(report_json_path).exists(): + with open(report_json_path, "r") as f: + report_data = json.load(f) + + # Add KV cache metrics to the report (only for autodeploy backend) + if backend == "_autodeploy": + report_data["kv_cache_metrics"] = kv_cache_metrics + report_data["backend"] = backend + return report_data + return None + + +def compare_backends_performance( + autodeploy_tokens_per_sec: float, + pytorch_tokens_per_sec: float, + relative_tolerance: float = 0.20, + absolute_tolerance: float = 10.0, +): + """ + Compare performance between autodeploy and pytorch backends. + Fails if autodeploy is significantly worse than pytorch. + + Args: + autodeploy_tokens_per_sec: Performance of autodeploy backend + pytorch_tokens_per_sec: Performance of pytorch backend + relative_tolerance: Relative tolerance (20% by default for backend comparison) + absolute_tolerance: Absolute tolerance (10 tokens/sec by default) + """ + # Calculate performance difference + performance_diff = pytorch_tokens_per_sec - autodeploy_tokens_per_sec + relative_diff = performance_diff / pytorch_tokens_per_sec if pytorch_tokens_per_sec > 0 else 0 + + print("=== BACKEND PERFORMANCE COMPARISON ===") + print(f"PyTorch backend: {pytorch_tokens_per_sec:.2f} tokens/sec/user") + print(f"Autodeploy backend: {autodeploy_tokens_per_sec:.2f} tokens/sec/user") + print(f"Performance difference: {performance_diff:.2f} tokens/sec ({relative_diff:.2%})") + + # If autodeploy is better than or equal to pytorch, always pass + if autodeploy_tokens_per_sec >= pytorch_tokens_per_sec: + print("✅ Autodeploy backend matches or exceeds PyTorch backend performance") + return + + # Autodeploy is slower - check if it's within acceptable tolerance + within_relative_tolerance = relative_diff <= relative_tolerance + within_absolute_tolerance = performance_diff <= absolute_tolerance + + if within_relative_tolerance or within_absolute_tolerance: + print("✅ Autodeploy backend performance within acceptable tolerance") + print( + f" Tolerance: {relative_tolerance:.2%} relative OR {absolute_tolerance:.2f} tokens/sec absolute" + ) + else: + assert False, ( + f"Autodeploy backend significantly underperforms compared to PyTorch! " + f"Autodeploy: {autodeploy_tokens_per_sec:.2f} tokens/sec/user, " + f"PyTorch: {pytorch_tokens_per_sec:.2f} tokens/sec/user, " + f"Performance gap: {performance_diff:.2f} tokens/sec ({relative_diff:.2%}), " + f"Tolerance: {relative_tolerance:.2%} relative OR {absolute_tolerance:.2f} tokens/sec absolute" + ) + + +def assert_performance_within_tolerance( + actual_tokens_per_sec: float, + golden_tokens_per_sec: float, + relative_tolerance: float = 0.15, + absolute_tolerance: float = 10.0, +): + """ + Assert that actual performance is within tolerance of golden result. + Only fails if performance is WORSE than golden - improvements always pass. + + Args: + actual_tokens_per_sec: Measured performance metric + golden_tokens_per_sec: Expected performance metric + relative_tolerance: Relative tolerance (15% by default) + absolute_tolerance: Absolute tolerance (10 tokens/sec by default) + """ + # If actual performance is better than or equal to golden, always pass + if actual_tokens_per_sec >= golden_tokens_per_sec: + print( + f"✅ Performance improvement detected:" + f" {actual_tokens_per_sec:.2f} >= {golden_tokens_per_sec:.2f} tokens/sec/user" + ) + return + + # Performance is worse than golden - check if it's within acceptable tolerance + performance_drop = golden_tokens_per_sec - actual_tokens_per_sec + relative_drop = ( + performance_drop / golden_tokens_per_sec if golden_tokens_per_sec > 0 else float("inf") + ) + + # Performance should be within relative tolerance OR absolute tolerance + within_relative_tolerance = relative_drop <= relative_tolerance + within_absolute_tolerance = performance_drop <= absolute_tolerance + + assert within_relative_tolerance or within_absolute_tolerance, ( + f"Performance regression detected! " + f"Actual: {actual_tokens_per_sec:.2f} tokens/sec/user, " + f"Golden: {golden_tokens_per_sec:.2f} tokens/sec/user, " + f"Performance drop: {performance_drop:.2f} tokens/sec ({relative_drop:.2%}), " + f"Tolerance: {relative_tolerance:.2%} relative OR {absolute_tolerance:.2f} tokens/sec absolute" + ) def prepare_dataset(root_dir: str, temp_dir: str, model_name: str): @@ -18,7 +234,7 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_name: str): dataset_tool = Path(root_dir, "benchmarks", "cpp", "prepare_dataset.py") script_dir = Path(root_dir, "benchmarks", "cpp") - # Generate a small dataset to run a test. + # Generate a small dataset to run a test - matching workload configuration command = [ "python3", f"{dataset_tool}", @@ -38,7 +254,9 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_name: str): "10", ] print(f"Running command: {' '.join(command)}") - result = subprocess.run(command, cwd=str(script_dir), capture_output=True, text=True) + result = subprocess.run( + command, cwd=str(script_dir), capture_output=True, text=True, timeout=300 + ) if result.returncode != 0: raise RuntimeError(f"Failed to prepare dataset: {result.stderr}") # Grab the stdout and write it to a dataset file for passing to suite. @@ -47,22 +265,324 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_name: str): return dataset_path -def run_benchmark(model_name: str, dataset_path: str, temp_dir: str): - runner = CliRunner() +def calculate_expected_kv_cache_metrics(free_mem_ratio: float): + """Calculate expected KV cache metrics based on actual GPU memory.""" + try: + import torch - args = [ - "--model", - model_name, - "throughput", - "--backend", - "_autodeploy", - "--dataset", - dataset_path, - "--extra_llm_api_options", - f"{temp_dir}/model_kwargs.yaml", + if torch.cuda.is_available(): + # Get total GPU memory in MB + _, total_mem_bytes = torch.cuda.mem_get_info(0) + total_mem_mb = total_mem_bytes // (1024 * 1024) + + # Estimate expected values based on model size + # For TinyLlama-1.1B, model should be 2.2GB + estimated_model_size_mb = 2200 # Conservative estimate + # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/6335 check why there is extra consumption + extra_consumption_mb = 2500 + expected_free_mem_range = ( + total_mem_mb - estimated_model_size_mb - extra_consumption_mb, + total_mem_mb - estimated_model_size_mb, + ) + + # Current cache size is typically small initially (16MB range) + expected_current_cache_size = 16777216 + + # Free memory values should be in reasonable range + expected_free_mem_pre_range = expected_free_mem_range + expected_free_mem_post_range = ( + expected_free_mem_range[0] - 1000, + expected_free_mem_range[1] - 500, + ) + + print("📊 GPU Memory Analysis:") + print(f" Total GPU memory: {total_mem_mb}MB") + print( + f" Expected free memory range: {expected_free_mem_range[0]}-{expected_free_mem_range[1]}MB" + ) + + return { + "total_mem_mb": total_mem_mb, + "expected_current_cache_size": expected_current_cache_size, + "expected_free_mem_pre_range": expected_free_mem_pre_range, + "expected_free_mem_post_range": expected_free_mem_post_range, + "free_mem_ratio": free_mem_ratio, + } + else: + return None + except ImportError: + return None + + +def validate_kv_cache_metrics_dynamic(kv_cache_metrics: dict, expected_metrics: dict): + """Validate KV cache metrics using dynamic expected values.""" + + # Validate current_cache_size (should be relatively stable) + current_cache_size = kv_cache_metrics.get("current_cache_size") + expected_cache_size = expected_metrics["expected_current_cache_size"] + if current_cache_size: + cache_diff = abs(current_cache_size - expected_cache_size) / expected_cache_size + assert cache_diff <= 0.5, ( # 50% tolerance for cache size + f"Current cache size outside expected range: {current_cache_size} vs expected ~{expected_cache_size}" + ) + print(f" ✅ current_cache_size: {current_cache_size} bytes (within range)") + + # Validate free memory values are in reasonable ranges + free_mem_pre = kv_cache_metrics.get("free_mem_pre_mb") + free_mem_post = kv_cache_metrics.get("free_mem_post_mb") + + if free_mem_pre: + pre_range = expected_metrics["expected_free_mem_pre_range"] + assert pre_range[0] <= free_mem_pre <= pre_range[1], ( + f"Free memory before forward pass outside expected range: " + f"{free_mem_pre}MB not in range {pre_range[0]}-{pre_range[1]}MB" + ) + print(f" ✅ free_mem_pre_mb: {free_mem_pre}MB (within range)") + + if free_mem_post: + post_range = expected_metrics["expected_free_mem_post_range"] + assert post_range[0] <= free_mem_post <= post_range[1], ( + f"Free memory after forward pass outside expected range: " + f"{free_mem_post}MB not in range {post_range[0]}-{post_range[1]}MB" + ) + print(f" ✅ free_mem_post_mb: {free_mem_post}MB (within range)") + + # Validate memory consumption (pre should be > post) + if free_mem_pre and free_mem_post: + memory_consumed = free_mem_pre - free_mem_post + assert memory_consumed > 0, ( + f"Expected memory consumption during forward pass, got {memory_consumed}MB" + ) + assert memory_consumed < 5000, f"Memory consumption too high: {memory_consumed}MB" + print(f" ✅ Memory consumed during forward pass: {memory_consumed}MB (reasonable)") + + # Validate calculated new_cache_size + new_cache_size = kv_cache_metrics.get("new_cache_size") + if new_cache_size and free_mem_post and current_cache_size: + expected_new_cache = int( + free_mem_post * 1024 * 1024 * expected_metrics["free_mem_ratio"] + current_cache_size + ) + cache_size_diff = abs(new_cache_size - expected_new_cache) / expected_new_cache + assert cache_size_diff <= 0.01, ( # 1% tolerance for calculated value + f"Calculated new_cache_size mismatch: {new_cache_size} vs expected {expected_new_cache}" + ) + print(f" ✅ new_cache_size: {new_cache_size} bytes (calculation correct)") + + +def extract_performance_metric(report_data, report_name="benchmark"): + """Extract performance metric from a benchmark report with validation.""" + assert report_data is not None, f"Failed to capture {report_name} report" + assert "performance" in report_data, f"Performance metrics not found in {report_name} report" + + tokens_per_sec = report_data["performance"].get("output_throughput_per_user_tok_s") + assert tokens_per_sec is not None, ( + f"output_throughput_per_user_tok_s not found in {report_name} performance metrics" + ) + + return tokens_per_sec + + +def validate_and_extract_kv_cache_metrics(report_data, free_mem_ratio, require_metrics=True): + """ + Validate and extract KV cache metrics from report. + + Args: + report_data: The benchmark report data + free_mem_ratio: Free memory ratio for calculating expected metrics + require_metrics: If True, fail when metrics are missing. If False, just warn. + + Returns: + Tuple of (kv_cache_metrics, expected_metrics) or (None, None) if validation fails + """ + required_metrics = [ + "current_cache_size", + "free_mem_pre_mb", + "free_mem_post_mb", + "new_cache_size", ] - result = runner.invoke(main, args, catch_exceptions=False) - assert result.exit_code == 0 + + # Extract KV cache metrics + kv_cache_metrics = report_data.get("kv_cache_metrics", {}) + + if not kv_cache_metrics: + message = ( + "KV cache metrics not found! " + "The autodeploy backend must log memory statistics for this test to pass. " + f"Expected metrics: {', '.join(required_metrics)}" + ) + if require_metrics: + assert False, f"REQUIRED {message}" + else: + print(f"ℹ️ {message}") + assert False, "KV cache metrics are missing" + + # Check for missing metrics + missing_metrics = [metric for metric in required_metrics if metric not in kv_cache_metrics] + + if missing_metrics: + message = ( + f"Missing required KV cache metrics: {missing_metrics}. " + f"Found metrics: {list(kv_cache_metrics.keys())}. " + f"All of {required_metrics} are required for the test to pass." + ) + if require_metrics: + assert False, message + else: + print(f"ℹ️ KV cache validation skipped - {message}") + assert False, "KV cache metrics are missing" + + # Calculate expected metrics + expected_metrics = calculate_expected_kv_cache_metrics(free_mem_ratio) + assert expected_metrics, "Could not determine expected metrics for this GPU" + + return kv_cache_metrics, expected_metrics + + +def print_kv_cache_metrics(kv_cache_metrics): + """Print KV cache metrics in a formatted way.""" + print("=== KV CACHE METRICS (DYNAMIC VALIDATION) ===") + for metric_name, actual_value in kv_cache_metrics.items(): + if "mb" in metric_name.lower(): + print(f"{metric_name}: {actual_value}MB") + else: + print(f"{metric_name}: {actual_value} bytes") + + +def trtllm_bench_unified_comparison( + llm_root, # noqa: F811 + comparison_mode="backend", + free_mem_ratio=0.1, + num_hidden_layers=2, + max_batch_size=32, # below this value the kv cache resizing is skipped + golden_tokens_per_sec=1400, + backend_relative_tolerance=0.2, + backend_absolute_tolerance=250.0, + golden_relative_tolerance=0.1, + golden_absolute_tolerance=5.0, +): + """ + Unified test that compares autodeploy backend performance in two modes: + - "backend": compares against pytorch backend performance + - "golden": compares against predefined golden performance values + + Args: + llm_root: Root directory for LLM models (pytest fixture) + comparison_mode: Either "backend" or "golden" to determine comparison type + free_mem_ratio: Ratio of free memory to use for KV cache + num_hidden_layers: Number of hidden layers for the model + max_batch_size: Maximum batch size for benchmarking + golden_tokens_per_sec: Golden performance value in tokens/sec/user + backend_relative_tolerance: Relative tolerance for backend comparison + backend_absolute_tolerance: Absolute tolerance for backend comparison + golden_relative_tolerance: Relative tolerance for golden comparison + golden_absolute_tolerance: Absolute tolerance for golden comparison + """ + model_name = _hf_model_dir_or_hub_id( + f"{llm_models_root()}/TinyLlama-1.1B-Chat-v1.0", "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + ) + + with tempfile.TemporaryDirectory() as temp_dir: + with open(f"{temp_dir}/extra_llm_api_options.yaml", "w") as f: + yaml.dump( + { + "model_kwargs": {"num_hidden_layers": num_hidden_layers}, + "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32], + "compile_backend": "torch-opt", + "free_mem_ratio": free_mem_ratio, + "runtime": "trtllm", + }, + f, + ) + + dataset_path = prepare_dataset(llm_root, temp_dir, model_name) + + # Always run autodeploy backend + autodeploy_report_path = f"{temp_dir}/autodeploy_report.json" + print("=== RUNNING AUTODEPLOY BACKEND ===") + autodeploy_report = run_benchmark( + model_name, + dataset_path, + temp_dir, + "_autodeploy", + autodeploy_report_path, + max_batch_size, + num_hidden_layers, + free_mem_ratio, + ) + + # Extract autodeploy performance metrics + autodeploy_tokens_per_sec = extract_performance_metric(autodeploy_report, "autodeploy") + + # Validate and extract KV cache metrics (now required for both modes after user's changes) + kv_cache_metrics, expected_metrics = validate_and_extract_kv_cache_metrics( + autodeploy_report, free_mem_ratio, require_metrics=True + ) + + if comparison_mode == "backend": + # Backend comparison mode: also run pytorch backend + pytorch_report_path = f"{temp_dir}/pytorch_report.json" + print("=== RUNNING PYTORCH BACKEND ===") + pytorch_report = run_benchmark( + model_name, + dataset_path, + temp_dir, + "pytorch", + pytorch_report_path, + max_batch_size, + num_hidden_layers, + free_mem_ratio, + ) + + # Extract pytorch performance metrics + pytorch_tokens_per_sec = extract_performance_metric(pytorch_report, "pytorch") + + # Compare backend performance + compare_backends_performance( + autodeploy_tokens_per_sec, + pytorch_tokens_per_sec, + relative_tolerance=backend_relative_tolerance, + absolute_tolerance=backend_absolute_tolerance, + ) + + # Validate KV cache metrics + validate_kv_cache_metrics_dynamic(kv_cache_metrics, expected_metrics) + print("✅ KV Cache Metrics validation passed") + + print("=== BACKEND COMPARISON TEST PASSED ===") + print(f"Autodeploy: {autodeploy_tokens_per_sec:.2f} tokens/sec/user") + print(f"PyTorch: {pytorch_tokens_per_sec:.2f} tokens/sec/user") + + elif comparison_mode == "golden": + # Golden comparison mode: compare against golden values + print("=== PERFORMANCE METRICS ===") + print(f"Measured performance: {autodeploy_tokens_per_sec:.2f} tokens/sec/user") + print(f"Golden performance: {golden_tokens_per_sec:.2f} tokens/sec/user") + + # Print KV cache metrics + print_kv_cache_metrics(kv_cache_metrics) + + # Performance validation + assert_performance_within_tolerance( + autodeploy_tokens_per_sec, + golden_tokens_per_sec, + relative_tolerance=golden_relative_tolerance, + absolute_tolerance=golden_absolute_tolerance, + ) + + # KV cache metrics validation + print( + f"Validating {len(kv_cache_metrics)} KV cache metrics against GPU-specific ranges..." + ) + validate_kv_cache_metrics_dynamic(kv_cache_metrics, expected_metrics) + + print("=== ALL TESTS PASSED ===") + print(f"Performance: ✅ {autodeploy_tokens_per_sec:.2f} tokens/sec/user within bounds") + print("KV Cache Metrics: ✅ All metrics within GPU-specific expected ranges") + + else: + raise ValueError( + f"Invalid comparison_mode: {comparison_mode}. Must be 'backend' or 'golden'" + ) @pytest.mark.skip("https://nvbugswb.nvidia.com/NVBugs5/redir.aspx?url=/5443039") @@ -72,15 +592,20 @@ def test_trtllm_bench(llm_root): # noqa: F811 ) with tempfile.TemporaryDirectory() as temp_dir: - with open(f"{temp_dir}/model_kwargs.yaml", "w") as f: + with open(f"{temp_dir}/extra_llm_api_options.yaml", "w") as f: yaml.dump( { "model_kwargs": {"num_hidden_layers": 2}, "cuda_graph_batch_sizes": [1, 2], - "max_batch_size": 128, }, f, ) dataset_path = prepare_dataset(llm_root, temp_dir, model_name) run_benchmark(model_name, dataset_path, temp_dir) + + +@pytest.mark.no_xdist +def test_trtllm_bench_backend_comparison(llm_root): # noqa: F811 + """Test that compares autodeploy backend performance against pytorch backend.""" + trtllm_bench_unified_comparison(llm_root, comparison_mode="backend") diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py index a813e9906a..93cfec18b5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py @@ -78,6 +78,9 @@ def _joint_transform(gm: GraphModule) -> None: ["eager", "sdpa"], ) def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str): + if attn_implementation == "sdpa": + pytest.skip("https://nvbugspro.nvidia.com/bug/5170222") + def verify_matcher(gm: GraphModule): """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa call in the graph. Also check that there is no repeat_kv pattern left. diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index 876eba196c..f2fd32ea3e 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -187,7 +187,7 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): # Helper function to call the model with proper sequence nesting def _call_and_unnest(x): # Use nest_sequences to properly set input_ids and automatically update position_ids - cm.info.nest_sequences(x) + cm.info.nest_sequences(x, allow_realloc=True) # Use the cm.args as is - it already contains the correct position_ids y = gm(*cm.args) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py index c937d11211..dba864f4f1 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py @@ -2,15 +2,12 @@ import pytest import torch import torch.nn as nn import torch.nn.functional as F -from _graph_test_helpers import run_test +from _graph_test_helpers import run_test_transformed_gm from _model_test_utils import MoEOpModel from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available -import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 -from tensorrt_llm._torch.auto_deploy.transformations.library.fused_moe import ( - fuse_moe, - match_moe_pattern, -) +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale @@ -304,11 +301,20 @@ def test_moe_matching(quant_type, expected_op, atol, rtol): model.block_sparse_moe.gate = model.block_sparse_moe.gate.to(dtype=torch.bfloat16) x = model.get_input(device=device) + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "match_moe_pattern": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) - _ = run_test( + run_test_transformed_gm( model, x, - match_moe_pattern, + gm_transformed, lambda gm: any(is_op(n, expected_op) for n in gm.graph.nodes), lambda num: num, atol=atol, @@ -322,11 +328,20 @@ def test_moe_fusion(): device = "cuda" model = MoEOpModel().to(device=device, dtype=torch.bfloat16) x = model.get_input(device=device, dtype=torch.bfloat16) + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_moe": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) - fused_gm_transformed = run_test( + run_test_transformed_gm( model, x, - fuse_moe, + gm_transformed, lambda gm: any( is_op( n, {torch.ops.auto_deploy.torch_moe_fused, torch.ops.auto_deploy.trtllm_moe_fused} @@ -342,7 +357,7 @@ def test_moe_fusion(): # expert weights are fused and stacked in fusion num_param_nodes = len(list(model.named_parameters())) - num_param_nodes_fused = len(list(fused_gm_transformed.named_parameters())) + num_param_nodes_fused = len(list(gm_transformed.named_parameters())) assert ( num_param_nodes_fused < num_param_nodes ), f"""number of parameter nodes after fusion {num_param_nodes_fused} < diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py index 6f2734bc6c..341edae905 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py @@ -73,7 +73,7 @@ def test_quantization(quant_config, atol, rtol, num_p_og): gm_transformed = InferenceOptimizer( DummyFactory(quant_config), { - "quantize": { + "quantize_from_config": { "stage": "pattern_matcher", }, }, @@ -155,7 +155,7 @@ def test_bmm_quantization(quant_config, atol, rtol, num_p_og, model_class): gm_transformed = InferenceOptimizer( DummyFactory(quant_config), { - "quantize": { + "quantize_from_config": { "stage": "pattern_matcher", }, }, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py index 644086cdf3..4458337cf7 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py @@ -3,14 +3,15 @@ import pytest import torch import torch.nn as nn -from _graph_test_helpers import run_test +from _graph_test_helpers import run_test_transformed_gm from torch.fx import GraphModule -from tensorrt_llm._torch.auto_deploy.transformations.library.eliminate_redundant_transposes import ( +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.library.eliminate_redundant_transposes import ( _is_contiguous_op, _is_transpose_op, - eliminate_redundant_transposes, ) +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer class RedundantTransposeModel(nn.Module): @@ -217,16 +218,25 @@ def test_eliminate_redundant_transposes_with_contiguous(model_class): # Setup model and input model = model_class().cuda() x = torch.randn(2, 3, 4).cuda() + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "eliminate_redundant_transposes": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) # Create a check function for this specific model expected_transpose_count = model.expected_remaining_transposes expected_contiguous_count = model.expected_remaining_contiguous # Run the test using the helper - run_test( + run_test_transformed_gm( model=model, x=x, - transform=eliminate_redundant_transposes, + gm_transformed=gm_transformed, check_transformed_graph=lambda gm: check_transpose_and_contiguous_count( gm, expected_transpose_count, expected_contiguous_count ), diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py index 436087e847..4a15eddfa5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py @@ -1,6 +1,6 @@ import pytest import torch -from _graph_test_helpers import run_test +from _graph_test_helpers import run_test_transformed_gm from _model_test_utils import ( apply_rotary_pos_emb_complex, apply_rotary_pos_emb_ds, @@ -8,11 +8,8 @@ from _model_test_utils import ( ) from torch.export import Dim -from tensorrt_llm._torch.auto_deploy.transformations.library.rope import ( - match_rope_layout, - match_rope_pattern, - optimize_rope, -) +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import extract_output_tuple, is_op torch.manual_seed(0) @@ -212,9 +209,19 @@ def test_rope_variants( ).to("cuda", torch.float16) x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16) dyn = model.get_dynamic_shapes() + gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dyn,), clone=True) if transformation == "match": - fn = match_rope_pattern + gm_transformed = InferenceOptimizer( + None, + { + "match_rope_pattern": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) + gm_transformed.to("cuda") + check_op = ( torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin if variant == "explicit" or variant == "explicit_pm" @@ -224,8 +231,32 @@ def test_rope_variants( def checker(gm): return any(is_op(n, check_op) for n in gm.graph.nodes) + run_test_transformed_gm( + model, + x, + gm_transformed, + checker, + lambda n: n, + atol, # atol + rtol, # rtol + True, # test_load_hook + True, # strict_loading + dyn, # dynamic_shapes + 1, # check_num_matches + False, # skip_output_assert + ) + elif transformation == "match_layout": - fn = match_rope_layout + gm_transformed = InferenceOptimizer( + None, + { + "match_rope_layout": { + "stage": "pattern_matcher", + "expected_layout": target_layout, + }, + }, + )(None, gm) + gm_transformed.to("cuda") def checker(gm): for n in gm.graph.nodes: @@ -254,17 +285,10 @@ def test_rope_variants( return matched if layout != target_layout else not matched - else: - fn = optimize_rope - - def checker(gm): - return any(is_op(n, torch.ops.auto_deploy.flashinfer_rope) for n in gm.graph.nodes) - - if transformation == "match_layout": - _ = run_test( + run_test_transformed_gm( model, x, - fn, + gm_transformed, checker, lambda n: n, atol, # atol @@ -274,28 +298,26 @@ def test_rope_variants( dyn, # dynamic_shapes None, # check_num_matches False, # skip_output_assert - target_layout, ) - elif transformation == "match": - _ = run_test( + + else: # optimize + gm_transformed = InferenceOptimizer( + None, + { + "optimize_rope": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) + gm_transformed.to("cuda") + + def checker(gm): + return any(is_op(n, torch.ops.auto_deploy.flashinfer_rope) for n in gm.graph.nodes) + + run_test_transformed_gm( model, x, - fn, - checker, - lambda n: n, - atol, # atol - rtol, # rtol - True, # test_load_hook - True, # strict_loading - dyn, # dynamic_shapes - 1, # check_num_matches - False, # skip_output_assert - ) - else: - _ = run_test( - model, - x, - fn, + gm_transformed, checker, lambda n: n, atol, # atol @@ -388,9 +410,18 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target x = torch.randn(batch, seq, hid, device="cuda", dtype=torch.float16) dynamic_shapes = model.get_dynamic_shapes() + gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True) if mode == "match": - transform = match_rope_pattern + gm_transformed = InferenceOptimizer( + None, + { + "match_rope_pattern": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) + gm_transformed.to("cuda") def checker(gm): return any( @@ -398,8 +429,32 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target for n in gm.graph.nodes ) + run_test_transformed_gm( + model, + x, + gm_transformed, + checker, + lambda num_p: num_p, + 1e-3, # atol + 1e-3, # rtol + True, # test_load_hook + True, # strict_loading + dynamic_shapes, # dynamic_shapes + 1, # check_num_matches + False, # skip_output_assert + ) + else: # mode == "match_layout" - transform = match_rope_layout + gm_transformed = InferenceOptimizer( + None, + { + "match_rope_layout": { + "stage": "pattern_matcher", + "expected_layout": target_layout, + }, + }, + )(None, gm) + gm_transformed.to("cuda") def checker(gm): for n in gm.graph.nodes: @@ -422,11 +477,10 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target return matched if layout != target_layout else not matched - if mode == "match_layout": - _ = run_test( + run_test_transformed_gm( model, x, - transform, + gm_transformed, checker, lambda num_p: num_p, 1e-3, # atol @@ -436,20 +490,4 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target dynamic_shapes, # dynamic_shapes None, # check_num_matches False, # skip_output_assert - target_layout, - ) - else: - _ = run_test( - model, - x, - transform, - checker, - lambda num_p: num_p, - 1e-3, # atol - 1e-3, # rtol - True, # test_load_hook - True, # strict_loading - dynamic_shapes, # dynamic_shapes - 1, # check_num_matches - False, # skip_output_assert ) diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py index 5e9f2ba1a2..4b63769735 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -75,6 +75,19 @@ def calc_diff(x, y): return 1 - sim +def calc_woq_tolerence(x: torch.Tensor, weight_dtype: torch.dtype): + # align with woq_assert_near_eq function in tests/unittest/trt/quantization/_utils.py + if weight_dtype == torch.int8: + bits_in_type = 8 + elif weight_dtype == torch.quint4x2: + bits_in_type = 4 + quant_range_scale = 1.0 / float(1 << (bits_in_type - 1)) + max_val = torch.max(abs(x)).item() + atol = (max_val * quant_range_scale) * 1.5 # allow for rounding + + return atol + + def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor, final_scales: torch.Tensor, num_experts: int, weights: Dict[str, torch.Tensor]) -> torch.Tensor: diff --git a/tests/unittest/_torch/modeling/test_modeling_exaone4.py b/tests/unittest/_torch/modeling/test_modeling_exaone4.py index 6b907db0a5..48ad7d8d83 100644 --- a/tests/unittest/_torch/modeling/test_modeling_exaone4.py +++ b/tests/unittest/_torch/modeling/test_modeling_exaone4.py @@ -65,7 +65,7 @@ EXAONE4_SINGLE_LAYER_CONFIG = { }, "rope_theta": 1000000, "sliding_window": 4, # NOTE: For testing, we use 4 instead of 4096 - "sliding_window_pattern": "LLLG", + "sliding_window_pattern": 4, "tie_word_embeddings": False, "torch_dtype": "bfloat16", "transformers_version": "4.54.0.dev0", diff --git a/tests/unittest/_torch/modeling/test_modeling_gpt_oss.py b/tests/unittest/_torch/modeling/test_modeling_gpt_oss.py new file mode 100644 index 0000000000..43de8f81b5 --- /dev/null +++ b/tests/unittest/_torch/modeling/test_modeling_gpt_oss.py @@ -0,0 +1,89 @@ +import json +import os +import shutil + +import pytest + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \ + IS_TRITON_KERNELS_AVAILABLE +from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig + +configs = """ +{ + "architectures": [ + "GptOssForCausalLM" + ], + "model_type": "gpt_oss", + "torch_dtype": "bfloat16", + "num_hidden_layers": 4, + "num_experts": 128, + "experts_per_token": 4, + "vocab_size": 201088, + "hidden_size": 2880, + "intermediate_size": 2880, + "head_dim": 64, + "num_attention_heads": 64, + "num_key_value_heads": 8, + "sliding_window": 128, + "initial_context_length": 4096, + "rope_theta": 150000, + "rope_scaling_factor": 32.0, + "rope_ntk_alpha": 1, + "rope_ntk_beta": 32 +} +""" + + +def dump_config_json(dst_dir): + if os.path.exists(dst_dir): + shutil.rmtree(dst_dir) + os.makedirs(dst_dir) + + dst_path = os.path.join(dst_dir, 'config.json') + with open(dst_path, 'w', encoding='utf-8') as f: + json_configs = json.loads(configs) + json.dump(json_configs, f, indent=2, ensure_ascii=False) + + +@pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRITON"]) +def test_gpt_oss_trtllmgen(moe_backend): + if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("Triton kernels are not available") + + pytest.skip("https://nvbugspro.nvidia.com/bug/5441721") + + prompts = [ + "How are you?", + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(), + attn_backend="TRTLLM", + load_format="dummy", + moe_config=MoeConfig(backend=moe_backend), + ) + + tmp_model_dir = f"/tmp/test_model_trtllm" + + dump_config_json(tmp_model_dir) + + llm = LLM(model=tmp_model_dir, + tensor_parallel_size=1, + enable_chunked_prefill=False, + **pytorch_config, + max_batch_size=16, + max_seq_len=1024, + moe_expert_parallel_size=-1, + moe_tensor_parallel_size=-1, + enable_attention_dp=False, + kv_cache_config=KvCacheConfig(enable_block_reuse=False, + free_gpu_memory_fraction=0.4)) + + sampling_params = SamplingParams(max_tokens=20) + llm.generate(prompts, sampling_params) diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py index 04ca1cd62f..a0d09c18c7 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py @@ -3,6 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import torch +import transformers from parameterized import parameterized from transformers import Llama4Config from transformers import \ @@ -265,6 +266,11 @@ class TestLlama4MinLatency(unittest.TestCase): attention_backend = "TRTLLM" metadata_cls = get_attention_backend(attention_backend).Metadata + if transformers.__version__ >= "4.55.0": + self.skipTest( + "The transformers 4.55.0 has accuracy issues while 4.33.1 works fine. " + "https://nvbugspro.nvidia.com/bug/5441729") + torch.random.manual_seed(0) config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG) # 17B * sizeof(float16) plus some extra for activations diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py index bb2b8a5b39..9117f39c7d 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py @@ -1,3 +1,4 @@ +import pytest import torch from utils.llm_data import llm_models_root from utils.util import skip_gpu_memory_less_than @@ -29,8 +30,10 @@ def extract_decode_logprobs(result: RequestOutput, return get_logprobs(token_ids, logits) -def create_nemotron_h_llm(use_cuda_graph, disable_overlap_scheduler, - max_batch_size): +def create_nemotron_h_llm(use_cuda_graph, + disable_overlap_scheduler, + max_batch_size, + mamba_ssm_cache_dtype=None): """Create LLM with specific overlap scheduler setting""" model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K" return LLM( @@ -39,14 +42,19 @@ def create_nemotron_h_llm(use_cuda_graph, disable_overlap_scheduler, max_batch_size=max_batch_size, cuda_graph_config=CudaGraphConfig() if use_cuda_graph else None, disable_overlap_scheduler=disable_overlap_scheduler, - kv_cache_config=KvCacheConfig(enable_block_reuse=False), - enable_trtllm_sampler=True, + kv_cache_config=KvCacheConfig( + enable_block_reuse=False, + mamba_ssm_cache_dtype="auto" + if mamba_ssm_cache_dtype is None else mamba_ssm_cache_dtype), + sampler_type="TRTLLMSampler", ) @skip_gpu_memory_less_than( (2 * 8 + 1) * 2**30) # 8B, bf16, plus 1 GB for good measure -def test_nemotron_h_correctness(): +@pytest.mark.parametrize("mamba_ssm_cache_dtype", [None, "float32"], + ids=lambda n: f"mamba_ssm_cache_dtype:{n}") +def test_nemotron_h_correctness(mamba_ssm_cache_dtype): # This test is close to memory limit on A30 (with 24GB), so empty cache first torch.cuda.empty_cache() @@ -56,9 +64,11 @@ def test_nemotron_h_correctness(): ] num_prompts = len(text_prompts) - nemotron_h = create_nemotron_h_llm(use_cuda_graph=False, - disable_overlap_scheduler=False, - max_batch_size=num_prompts) + nemotron_h = create_nemotron_h_llm( + use_cuda_graph=False, + disable_overlap_scheduler=False, + max_batch_size=num_prompts, + mamba_ssm_cache_dtype=mamba_ssm_cache_dtype) expected_completions = [ " bright, with endless possibilities for innovation and growth", @@ -237,6 +247,7 @@ def test_nemotron_h_correctness(): nemotron_h.shutdown() +@pytest.mark.skip(reason="https://nvbugs/5458874") def test_nemotron_h_cuda_graph_overlap_scheduler(): prompts = [ "The sky is blue because", diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_nas.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_nas.py index f5e395401d..9f4ed2df0f 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron_nas.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_nas.py @@ -357,6 +357,7 @@ class TestNemotronNAS(unittest.TestCase): ], lambda testcase_func, param_num, param: f"{testcase_func.__name__}[{param.args[0]}]") @torch.no_grad() + @unittest.skip("https://nvbugspro.nvidia.com/bug/5439817") def test_nemotron_nas_allclose_to_hf(self, scenario: Scenario) -> None: """ Compare output to HF diff --git a/tests/unittest/_torch/modeling/test_modeling_pixtral.py b/tests/unittest/_torch/modeling/test_modeling_pixtral.py index f47a0d4b11..50ebe28f35 100644 --- a/tests/unittest/_torch/modeling/test_modeling_pixtral.py +++ b/tests/unittest/_torch/modeling/test_modeling_pixtral.py @@ -28,8 +28,7 @@ mpi4py.MPI.pickle.__init__( pytestmark = pytest.mark.threadleak(enabled=False) -@pytest.fixture -def pixtral_vision_config(): +def make_pixtral_vision_config(): # Values taken from: # https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/config.json return model_config_lib.ModelConfig( @@ -71,9 +70,10 @@ def init_hf_model(cls, config, dtype, device): @torch.no_grad() @pytest.mark.usefixtures("set_seed") -def test_pixtral_vision_model_vs_hf(pixtral_vision_config): +def test_pixtral_vision_model_vs_hf(): dtype = torch.bfloat16 device = torch.device("cuda") + pixtral_vision_config = make_pixtral_vision_config() pretrained_config = pixtral_vision_config.pretrained_config pixtral_model = ( @@ -111,13 +111,14 @@ def test_pixtral_vision_model_vs_hf(pixtral_vision_config): @pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) @torch.no_grad() -def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path): +def test_tensor_parallelism(mpi_pool_executor, tmp_path): mapping = mapping_lib.Mapping(world_size=2, tp_size=2) if (num_available_devices := torch.cuda.device_count()) < mapping.world_size: pytest.skip(f"{num_available_devices=} is less than the requested {mapping.world_size}.") dtype = torch.bfloat16 device = torch.device("cuda") + pixtral_vision_config = make_pixtral_vision_config() pretrained_config = pixtral_vision_config.pretrained_config hf_pixtral_model = init_hf_model( @@ -157,20 +158,22 @@ def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path): gc.collect() torch.cuda.empty_cache() + # NOTE: we cannot send `pixtral_vision_config` across the process barrier, as it contains + # `weakref` objects, which cannot be pickled. Instead, each worker will recreate it by + # calling the `make_pixtral_vision_config` function. world_size = mapping.world_size - pixtral_vision_config.mapping = mapping results = mpi_pool_executor.starmap( _run_pixtral_and_compare_against_ref, [ ( - pixtral_vision_config, + mapping_lib.Mapping(tp_size=world_size, world_size=world_size, rank=rank), hf_weights_path, pixel_values, image_sizes, ref_out, num_params, ) - for _ in range(world_size) + for rank in range(world_size) ], ) @@ -179,7 +182,7 @@ def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path): def _run_pixtral_and_compare_against_ref( - pixtral_vision_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig], + mapping: mapping_lib.Mapping, hf_weights_path: pathlib.Path, pixel_values: torch.Tensor, image_sizes: torch.Tensor, @@ -197,7 +200,8 @@ def _run_pixtral_and_compare_against_ref( image_sizes = image_sizes.to("cuda") expected_output = expected_output.to("cuda") - pixtral_vision_config.mapping.rank = rank + pixtral_vision_config = make_pixtral_vision_config() + pixtral_vision_config.mapping = mapping pixtral_model = ( modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to("cuda") ) diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index a72ad4c7b6..2d11971d99 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -9,21 +9,17 @@ import cloudpickle import pytest import torch import torch.nn as nn -from _torch.helpers import (per_block_cast_to_fp8, per_block_cast_to_fp8_e8m0, +from _torch.helpers import (calc_woq_tolerence, per_block_cast_to_fp8, + per_block_cast_to_fp8_e8m0, per_token_cast_to_fp8_e8m0) from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor -from utils.util import (skip_neither_ada_nor_hopper_unittest, +from utils.util import (check_accuracy, skip_neither_ada_nor_hopper_unittest, skip_non_hopper_unittest, skip_pre_blackwell, skip_pre_hopper) from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.model_config import ModelConfig -from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod, - CutlassFusedMoE, - DefaultMoeRoutingMethod, - RenormalizeMoeRoutingMethod, - VanillaMoE, WideEPMoE) from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \ CuteDslFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \ @@ -31,6 +27,18 @@ from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \ from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \ AlltoallMethodType from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode + +# isort and yapf will fight against each other here, so we disable isort +# isort: off +from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod, + CutlassFusedMoE, + DefaultMoeRoutingMethod, + RenormalizeMoeRoutingMethod, + TritonFusedMoE, VanillaMoE, + create_moe, WideEPMoE) +# isort: on +from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \ + IS_TRITON_KERNELS_AVAILABLE from tensorrt_llm._torch.modules.gated_mlp import GatedMLP from tensorrt_llm._utils import mpi_rank from tensorrt_llm.mapping import Mapping @@ -46,65 +54,93 @@ MPI.pickle.__init__( @pytest.mark.parametrize( - "moe_cls, dtype, experts, RoutingMethodCls", - product([CutlassFusedMoE, VanillaMoE], [torch.float16, torch.bfloat16], - [3, 8, 512], - [DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod])) -def test_fused_moe(moe_cls, dtype, experts, RoutingMethodCls, mapping=None): - SEQ_LEN = 8 - HIDDEN_SIZE = 64 - INTERMEDIATE_SIZE = 32 - NUM_EXPERTS = experts - TOP_K = 2 - routing_method = RoutingMethodCls(top_k=TOP_K) - mapping = mapping or Mapping() + "moe_backend, dtype, experts, routing_cls, bias", + product(["CUTLASS", "VANILLA", "TRITON"], [torch.float16, torch.bfloat16], + [3, 8, 512], [DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod], + [True, False])) +def test_fused_moe(moe_backend, + dtype, + experts, + routing_cls, + bias, + mapping=None): + + if moe_backend == "TRITON": + if not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("Triton kernels are not available") + if dtype != torch.bfloat16: + pytest.skip("Unsupported for TritonFusedMoE") + if routing_cls != RenormalizeMoeRoutingMethod: + pytest.skip("Unsupported for TritonFusedMoE") + + if bias and moe_backend not in ["TRITON"]: + pytest.skip("Bias not supported.") + + mapping = Mapping() mapping.rank = mpi_rank() - torch.cuda.set_device(mapping.rank) - torch.manual_seed(0) - torch.cuda.manual_seed(0) - x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") - router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), - dtype=dtype, - device="cuda") - weights = {} - for expert_id in range(NUM_EXPERTS): - w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), - dtype=dtype, - device="cuda") - w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), - dtype=dtype, - device="cuda") - w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), - dtype=dtype, - device="cuda") - weights[f"{expert_id}.w1.weight"] = w1_weight - weights[f"{expert_id}.w2.weight"] = w2_weight - weights[f"{expert_id}.w3.weight"] = w3_weight - fused_moe = moe_cls( - num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, - reduce_results=True, - model_config=ModelConfig(mapping=mapping), - ) - fused_moe.load_weights([weights]) - fused_moe.cuda() + with torch.device(f'cuda:{mapping.rank}'): + SEQ_LEN = 8 + HIDDEN_SIZE = 64 + INTERMEDIATE_SIZE = 32 + NUM_EXPERTS = experts + TOP_K = 2 + routing_method = routing_cls(top_k=TOP_K) - AutoTuner.get().clear_cache() - with torch.inference_mode(), autotune(): - fused_moe.forward(x, router_logits) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") - ref_fused_moe = RefGatedMLPFusedMoE(num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, - model_config=ModelConfig()) - ref_fused_moe.load_weights([weights]) - ref_fused_moe.cuda() + weights = {} + for expert_id in range(NUM_EXPERTS): + if bias: + w1_bias = torch.randn((INTERMEDIATE_SIZE, ), dtype=dtype).cuda() + w2_bias = torch.randn((HIDDEN_SIZE, ), dtype=dtype).cuda() + w3_bias = torch.randn((INTERMEDIATE_SIZE, ), dtype=dtype).cuda() + weights[f"{expert_id}.w1.bias"] = w1_bias + weights[f"{expert_id}.w2.bias"] = w2_bias + weights[f"{expert_id}.w3.bias"] = w3_bias + w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + weights[f"{expert_id}.w1.weight"] = w1_weight + weights[f"{expert_id}.w2.weight"] = w2_weight + weights[f"{expert_id}.w3.weight"] = w3_weight + fused_moe = create_moe( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, moe_backend=moe_backend), + bias=bias, + ) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + AutoTuner.get().clear_cache() + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + + ref_fused_moe = RefGatedMLPFusedMoE(num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(), + bias=bias) + ref_fused_moe.load_weights([weights]) + ref_fused_moe.cuda() # Evaluate the outputs on a variant sequence length to cover all possible keys in Autotuner cache m = SEQ_LEN @@ -120,7 +156,10 @@ def test_fused_moe(moe_cls, dtype, experts, RoutingMethodCls, mapping=None): # Evaluate outputs torch.cuda.synchronize() - torch.testing.assert_close(output, ref_output, rtol=0.5, atol=0.5) + # There can be one off mismatch in the outputs due to different kernel implementations + # Here we check 99% of the outputs are within the tolerance + # The CutlassFusedMoE case fails as well without this change on H100 for bf16 + check_accuracy(output, ref_output, rtol=0.2, atol=0.2, percent=0.984) m //= 2 @@ -250,94 +289,309 @@ def test_fused_moe_alltoall(alltoall_method_type): assert r is None -@skip_pre_hopper -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe_fp8(dtype): - SEQ_LEN = 4 - HIDDEN_SIZE = 64 - INTERMEDIATE_SIZE = 32 - NUM_EXPERTS = 3 - TOP_K = 2 - routing_method = DefaultMoeRoutingMethod(top_k=TOP_K) +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +@pytest.mark.parametrize("alltoall_method_type", [ + AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP, + AlltoallMethodType.DeepEPLowLatency +], + ids=lambda s: s.name) +def test_fused_moe_alltoall_fp4(alltoall_method_type): + + world_size = 4 + dtype = torch.bfloat16 + HIDDEN_SIZE = 2560 + INTERMEDIATE_SIZE = 1536 + NUM_EXPERTS = 72 + TOP_K = 6 + MAX_NUM_TOKENS = 2048 + torch.manual_seed(0) torch.cuda.manual_seed(0) - x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") - _, x_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(x) - x_scale = x_scale.float().squeeze() - router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), - dtype=dtype, - device="cuda") - weights = {} - for expert_id in range(NUM_EXPERTS): - w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), - dtype=dtype, - device="cuda") - w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), - dtype=dtype, - device="cuda") - w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), - dtype=dtype, - device="cuda") + x_list_world = [] + weights_world = [] - w1_weight_fp8, w1_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( - w1_weight) - w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + for i in range(world_size): + x_list = [] + m = MAX_NUM_TOKENS + while m >= 1: + x = torch.randn((m, HIDDEN_SIZE), dtype=dtype, device="cuda") + x_list.append(x.cuda(i)) + m //= 2 - w2_weight_fp8, w2_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( - w2_weight) - w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + x_abs_max = torch.cat([x.flatten() for x in x_list]).abs().max().float() + x_sf_global = (448 * 6) / x_abs_max - w3_weight_fp8, w3_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( - w3_weight) - w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + weights = {} + for expert_id in range(NUM_EXPERTS): - w1_input_scale = x_scale.cuda() - w2_input_scale = x_scale.cuda() - w3_input_scale = x_scale.cuda() + w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w1_sf_global = (448 * 6) / w1_weight.abs().max().float() - weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 - weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 - weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 - weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale.float() - weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale.float() - weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale.float() - weights[f"{expert_id}.w1.input_scale"] = w1_input_scale - weights[f"{expert_id}.w2.input_scale"] = w2_input_scale - weights[f"{expert_id}.w3.input_scale"] = w3_input_scale + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w2_sf_global = (448 * 6) / w2_weight.abs().max().float() - quant_config = QuantConfig(quant_algo=QuantAlgo.FP8) - fused_moe = CutlassFusedMoE( - num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, - reduce_results=False, - model_config=ModelConfig(quant_config=quant_config)) - fused_moe.cuda() - fused_moe.load_weights([weights]) + w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w3_sf_global = (448 * 6) / w3_weight.abs().max().float() - AutoTuner.get().clear_cache() - with torch.inference_mode(), autotune(): - fused_moe.forward(x, router_logits) + w3_w1_global = min( + w1_sf_global, + w3_sf_global) # w3 global and w1 global must be the same - ref_fused_moe = RefGatedMLPFusedMoE( - num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, - model_config=ModelConfig(quant_config=quant_config)) - ref_fused_moe.load_weights([weights]) - ref_fused_moe.cuda() - with torch.inference_mode(): - output = fused_moe.forward(x, router_logits) - ref_output = ref_fused_moe.forward(x, router_logits) + SCALING_VECTOR_SIZE = 16 - # compare - torch.cuda.synchronize() - torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.2) + w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize( + w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False) + w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( + w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1)) + + w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize( + w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False) + w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( + w2_sf_block.cpu().view(HIDDEN_SIZE, -1)) + + w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize( + w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False) + w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( + w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1)) + + w1_input_scale = x_sf_global.cuda(i) + w2_input_scale = x_sf_global.cuda(i) + w3_input_scale = x_sf_global.cuda(i) + + weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4.cuda(i) + weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4.cuda(i) + weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4.cuda(i) + weights[ + f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.cuda(i) + weights[ + f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.cuda(i) + weights[ + f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.cuda(i) + + weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale.cuda( + i) + weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale.cuda( + i) + weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale.cuda( + i) + weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global.cuda( + i) + weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global.cuda( + i) + weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global.cuda( + i) + + x_list_world.append(x_list) + weights_world.append(weights) + + def per_rank_test_fused_moe_alltoall(job_id): + routing_method = DefaultMoeRoutingMethod(top_k=TOP_K) + mapping = Mapping(world_size=world_size, + rank=mpi_rank(), + tp_size=world_size, + moe_ep_size=world_size, + moe_tp_size=1, + enable_attention_dp=True) + torch.cuda.set_device(mapping.rank) + torch.manual_seed(mapping.rank) + + x_list = x_list_world[mapping.rank] + weights = weights_world[mapping.rank] + + quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) + with mock.patch.object(WideEPMoE, + "select_alltoall_method_type", + return_value=alltoall_method_type): + alltoall_model = WideEPMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, + max_num_tokens=MAX_NUM_TOKENS, + quant_config=quant_config), + ) + alltoall_model.to("cuda") + alltoall_model.load_weights([weights]) + + ref_model = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, + max_num_tokens=MAX_NUM_TOKENS, + quant_config=quant_config), + ) + ref_model.to("cuda") + ref_model.load_weights([weights]) + + # Evaluate the outputs on a variant sequence length to verify the robustness of alltoall methods + m = MAX_NUM_TOKENS + i = 0 + while m >= 1: + x = x_list[i] + i += 1 + router_logits = torch.randn((m, NUM_EXPERTS), + dtype=dtype, + device="cuda") + all_rank_num_tokens = [m] * mapping.world_size + + with torch.inference_mode(): + output = alltoall_model.forward( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=m, + use_dp_padding=False) + ref_output = ref_model.forward( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=m, + use_dp_padding=False) + + # Evaluate outputs + torch.testing.assert_close(output, ref_output, rtol=0.05, atol=0.5) + m //= 2 + + with MPIPoolExecutor(max_workers=world_size) as executor: + results = executor.map(per_rank_test_fused_moe_alltoall, + range(world_size)) + for r in results: + assert r is None + + +@skip_pre_hopper +@pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRITON"]) +@pytest.mark.parametrize("routing_cls", + [DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_fp8(moe_backend, dtype, routing_cls, bias): + + if moe_backend == "TRITON": + if not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("Triton kernels are not available") + if dtype != torch.bfloat16: + pytest.skip("Unsupported for TritonFusedMoE") + if routing_cls != RenormalizeMoeRoutingMethod: + pytest.skip("Unsupported for TritonFusedMoE") + + if bias and moe_backend not in ["TRITON"]: + pytest.skip("Bias not supported.") + + mapping = Mapping() + mapping.rank = mpi_rank() + + with torch.device(f'cuda:{mapping.rank}'): + SEQ_LEN = 4 + HIDDEN_SIZE = 64 + INTERMEDIATE_SIZE = 32 + NUM_EXPERTS = 3 + TOP_K = 2 + routing_method = routing_cls(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + _, x_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(x) + x_scale = x_scale.float().squeeze() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") + + weights = {} + for expert_id in range(NUM_EXPERTS): + if bias: + w1_bias = torch.randn((INTERMEDIATE_SIZE, ), dtype=dtype).cuda() + w2_bias = torch.randn((HIDDEN_SIZE, ), dtype=dtype).cuda() + w3_bias = torch.randn((INTERMEDIATE_SIZE, ), dtype=dtype).cuda() + weights[f"{expert_id}.w1.bias"] = w1_bias + weights[f"{expert_id}.w2.bias"] = w2_bias + weights[f"{expert_id}.w3.bias"] = w3_bias + w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + + w1_weight_fp8, w1_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( + w1_weight) + w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w2_weight_fp8, w2_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( + w2_weight) + w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w3_weight_fp8, w3_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor( + w3_weight) + w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w1_input_scale = x_scale.cuda() + w2_input_scale = x_scale.cuda() + w3_input_scale = x_scale.cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 + weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 + weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale.float() + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale.float() + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale.float() + weights[f"{expert_id}.w1.input_scale"] = w1_input_scale + weights[f"{expert_id}.w2.input_scale"] = w2_input_scale + weights[f"{expert_id}.w3.input_scale"] = w3_input_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8) + fused_moe = create_moe(num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=False, + model_config=ModelConfig( + quant_config=quant_config, + moe_backend=moe_backend), + bias=bias) + fused_moe.cuda() + fused_moe.load_weights([weights]) + + AutoTuner.get().clear_cache() + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + + ref_fused_moe = RefGatedMLPFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(quant_config=quant_config), + bias=bias) + ref_fused_moe.load_weights([weights]) + ref_fused_moe.cuda() + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref_fused_moe.forward(x, router_logits) + + # compare + torch.cuda.synchronize() + check_accuracy(output, ref_output, rtol=0.04, atol=0.1, percent=0.99) def set_tensor_value_2(x, num_row, num_cols): @@ -560,7 +814,7 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype, torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) -@skip_non_hopper_unittest +@skip_pre_blackwell @pytest.mark.parametrize( "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode", product( @@ -572,13 +826,13 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype, [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ], ), ) -def test_fused_moe_fp8_blockwise(dtype, - num_experts, - seq_len, - hidden_size, - RoutingMethodCls, - WeightLoadingMode, - mapping=None): +def test_fused_moe_fp8_blockwise_cute_dsl(dtype, + num_experts, + seq_len, + hidden_size, + RoutingMethodCls, + WeightLoadingMode, + mapping=None): SEQ_LEN = seq_len HIDDEN_SIZE = hidden_size INTERMEDIATE_SIZE = 1536 @@ -671,7 +925,128 @@ def test_fused_moe_fp8_blockwise(dtype, fused_moe.cuda() fused_moe.load_weights([weights]) - fused_moe_origin = CutlassFusedMoE( + ref_fused_moe = RefGatedMLPFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(quant_config=quant_config), + # Note: use deepgemm mm will cause accuracy error, so we use trtllmgen mm here + use_cute_dsl_blockscaling_mm=True, + ) + ref_fused_moe.load_weights([weights]) + ref_fused_moe.cuda() + + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref_fused_moe.forward(x, router_logits) + + # compare + torch.cuda.synchronize() + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + return True + + +@skip_non_hopper_unittest +@pytest.mark.parametrize( + "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode", + product( + [torch.bfloat16], + [72], + [128, 256, 384, 512, 1024, 2048, 4096, 8192], + [2560], + [DefaultMoeRoutingMethod], + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ], + ), +) +def test_fused_moe_fp8_blockwise_cutlass(dtype, + num_experts, + seq_len, + hidden_size, + RoutingMethodCls, + WeightLoadingMode, + mapping=None): + SEQ_LEN = seq_len + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = 1536 + NUM_EXPERTS = num_experts + TOP_K = 6 + + routing_method = RoutingMethodCls(top_k=TOP_K) + + mapping = mapping or Mapping() + mapping.rank = mpi_rank() + torch.cuda.set_device(mapping.rank) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + # Note: we use some special values init x and weight, otherwise the test will false positive failed. + set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE) + + x = x.cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") + + weights = {} + + if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + weights['gate_up_proj'] = {} + weights['down_proj'] = {} + weights['gate_up_proj_weight_scale'] = {} + weights['down_proj_weight_scale'] = {} + + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE) + set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + + w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8(w1_weight) + w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8(w2_weight) + w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8(w3_weight) + w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 + weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 + weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale + + if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + weights['gate_up_proj'][expert_id] = torch.cat( + [w3_weight_fp8, w1_weight_fp8], + dim=-2).transpose(0, 1).contiguous() + weights['down_proj'][expert_id] = w2_weight_fp8.transpose( + 0, 1).contiguous() + weights['gate_up_proj_weight_scale'][expert_id] = torch.cat( + [w3_weight_scale, w1_weight_scale], + dim=-2).transpose(0, 1).contiguous() + weights['down_proj_weight_scale'][ + expert_id] = w2_weight_scale.transpose(0, 1).contiguous() + elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA: + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) + + fused_moe = CutlassFusedMoE( num_experts=NUM_EXPERTS, routing_method=routing_method, hidden_size=HIDDEN_SIZE, @@ -681,8 +1056,8 @@ def test_fused_moe_fp8_blockwise(dtype, model_config=ModelConfig(quant_config=quant_config, mapping=mapping), weight_loading_mode=WeightLoadingMode, ) - fused_moe_origin.cuda() - fused_moe_origin.load_weights([weights]) + fused_moe.cuda() + fused_moe.load_weights([weights]) ref_fused_moe = RefGatedMLPFusedMoE( num_experts=NUM_EXPERTS, @@ -697,13 +1072,10 @@ def test_fused_moe_fp8_blockwise(dtype, with torch.inference_mode(): output = fused_moe.forward(x, router_logits) - output_origin = fused_moe_origin.forward(x, router_logits) ref_output = ref_fused_moe.forward(x, router_logits) # compare torch.cuda.synchronize() - torch.testing.assert_close(output_origin, output, rtol=1e-2, atol=0.1) - torch.testing.assert_close(output_origin, ref_output, rtol=1e-2, atol=0.1) torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) return True @@ -713,17 +1085,55 @@ def test_fused_moe_fp8_blockwise(dtype, reason="needs 4 GPUs to run this test") @pytest.mark.parametrize("ep_size", [1, 2, 4]) @pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod]) -def test_fused_moe_fp8_blockwise_multi_gpu(ep_size, routing_method): +@pytest.mark.parametrize( + "weight_loading_mode", + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ]) +def test_fused_moe_fp8_blockwise_cutlass_multi_gpu(ep_size, routing_method, + weight_loading_mode): world_size = 4 with MPIPoolExecutor(max_workers=world_size) as executor: results = executor.map( - test_fused_moe_fp8_blockwise, + test_fused_moe_fp8_blockwise_cutlass, *zip(*[( torch.bfloat16, 72, 384, 384, routing_method, + weight_loading_mode, + Mapping( + world_size=world_size, + tp_size=world_size, + moe_ep_size=ep_size, + moe_tp_size=world_size // ep_size, + ), + )] * world_size), + ) + for r in results: + assert r is True + + +@skip_pre_blackwell +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +@pytest.mark.parametrize("ep_size", [1, 2, 4]) +@pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod]) +@pytest.mark.parametrize( + "weight_loading_mode", + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ]) +def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method, + weight_loading_mode): + world_size = 4 + with MPIPoolExecutor(max_workers=world_size) as executor: + results = executor.map( + test_fused_moe_fp8_blockwise_cute_dsl, + *zip(*[( + torch.bfloat16, + 72, + 384, + 384, + routing_method, + weight_loading_mode, Mapping( world_size=world_size, tp_size=world_size, @@ -739,89 +1149,458 @@ def test_fused_moe_fp8_blockwise_multi_gpu(ep_size, routing_method): @skip_pre_blackwell @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_moe_nvfp4(dtype): - SCALING_VECTOR_SIZE = 16 + mapping = Mapping() + mapping.rank = mpi_rank() - SEQ_LEN = 4 - HIDDEN_SIZE = 128 - INTERMEDIATE_SIZE = 128 - NUM_EXPERTS = 3 - TOP_K = 2 - routing_method = DefaultMoeRoutingMethod(top_k=TOP_K) + with torch.device(f'cuda:{mapping.rank}'): + SCALING_VECTOR_SIZE = 16 + + SEQ_LEN = 4 + HIDDEN_SIZE = 128 + INTERMEDIATE_SIZE = 128 + NUM_EXPERTS = 3 + TOP_K = 2 + routing_method = DefaultMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + x_sf_global = (448 * 6) / x.abs().max().float() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") + + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w1_sf_global = (448 * 6) / w1_weight.abs().max().float() + + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w2_sf_global = (448 * 6) / w2_weight.abs().max().float() + + w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w3_sf_global = (448 * 6) / w3_weight.abs().max().float() + + w3_w1_global = min( + w1_sf_global, + w3_sf_global) # w3 global and w1 global must be the same + + w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize( + w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False) + w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( + w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1)) + + w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize( + w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False) + w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( + w2_sf_block.cpu().view(HIDDEN_SIZE, -1)) + + w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize( + w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False) + w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( + w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1)) + + w1_input_scale = x_sf_global.cuda() + w2_input_scale = x_sf_global.cuda() + w3_input_scale = x_sf_global.cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4 + weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4 + weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4 + weights[ + f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.view( + torch.float8_e4m3fn).cuda() + weights[ + f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.view( + torch.float8_e4m3fn).cuda() + weights[ + f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.view( + torch.float8_e4m3fn).cuda() + weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale + weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale + weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale + weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global + weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global + weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global + + quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) + fused_moe = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=False, + model_config=ModelConfig(quant_config=quant_config)) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + # Evaluate the outputs on a variant sequence length to cover all possible keys in Autotuner cache + ref_fused_moe = RefGatedMLPFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(quant_config=quant_config)) + ref_fused_moe.load_weights([weights]) + ref_fused_moe.cuda() + + AutoTuner.get().clear_cache() + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref_fused_moe.forward(x, router_logits) + + # compare + torch.cuda.synchronize() + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15) + + +@skip_neither_ada_nor_hopper_unittest +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "weight_loading_mode", + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.W4A8_CUSTOM]) +def test_fused_moe_w4afp8(dtype, weight_loading_mode): + mapping = Mapping() + mapping.rank = mpi_rank() + + with torch.device(f'cuda:{mapping.rank}'): + SEQ_LEN = 4 + HIDDEN_SIZE = 768 + INTERMEDIATE_SIZE = 640 + SCALING_GROUP_SIZE = 128 + NUM_EXPERTS = 3 + TOP_K = 2 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") + + affine_coeff = 0.005 + + lut = { + "weight": + "weight", + "weight_scale": + ("weight_scale_inv" if weight_loading_mode + == MoEWeightLoadingMode.W4A8_CUSTOM else "weight_scale"), + "weight_scale_2": + "weight_scale_2", + "pre_quant_scale": + "pre_quant_scale", + "input_scale": + "input_scale", + } + + weights = {} + for expert_id in range(NUM_EXPERTS): + # ModelOpt W4A8 packs pairs of 4b weights in the output dimension into one 8b element. + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_shape = (INTERMEDIATE_SIZE // 2, HIDDEN_SIZE) + w2_shape = (HIDDEN_SIZE // 2, INTERMEDIATE_SIZE) + w3_shape = (INTERMEDIATE_SIZE // 2, HIDDEN_SIZE) + # The custom W4A8 quantization script examples/quantization/quantize_mixed_precision_moe.py + # packs pairs of 4b weight in the input dimension into one 8b element. + if weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM: + w1_shape = (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2) + w2_shape = (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2) + w3_shape = (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2) + + # The weights in int4 precision. + w1_weight = torch.randint(-128, 127, w1_shape, + dtype=torch.int8).cuda() + w2_weight = torch.randint(-128, 127, w2_shape, + dtype=torch.int8).cuda() + w3_weight = torch.randint(-128, 127, w3_shape, + dtype=torch.int8).cuda() + + # The pre-quant scale to be multiplied with the input activation. + w1_pre_quant_scale = torch.ones(HIDDEN_SIZE, + dtype=dtype, + device="cuda") + w2_pre_quant_scale = torch.ones(INTERMEDIATE_SIZE, + dtype=dtype, + device="cuda") + w3_pre_quant_scale = torch.ones(HIDDEN_SIZE, + dtype=dtype, + device="cuda") + + # The weight scale to dequantize int4 weights (by multiplication). + w1_scale = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda") * affine_coeff + w2_scale = torch.randn( + (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda") * affine_coeff + w3_scale = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda") * affine_coeff + + # The input scale to quantize the input activation (by division). + w1_input_scale = torch.randn(1, dtype=torch.float32, + device="cuda") * 0.2 + w2_input_scale = w1_input_scale + w3_input_scale = w1_input_scale + + # The weight scale 2 to quantize the dequantized weights (by division). + w1_weight_scale_2 = torch.ones([1], + dtype=torch.float32, + device="cuda") + w2_weight_scale_2 = w1_weight_scale_2 + w3_weight_scale_2 = w1_weight_scale_2 + + # Prepare weights. + weights[f"{expert_id}.w1.{lut['weight']}"] = w1_weight + weights[f"{expert_id}.w2.{lut['weight']}"] = w2_weight + weights[f"{expert_id}.w3.{lut['weight']}"] = w3_weight + weights[f"{expert_id}.w1.{lut['input_scale']}"] = w1_input_scale + weights[f"{expert_id}.w2.{lut['input_scale']}"] = w2_input_scale + weights[f"{expert_id}.w3.{lut['input_scale']}"] = w3_input_scale + weights[f"{expert_id}.w1.{lut['weight_scale']}"] = w1_scale + weights[f"{expert_id}.w2.{lut['weight_scale']}"] = w2_scale + weights[f"{expert_id}.w3.{lut['weight_scale']}"] = w3_scale + weights[ + f"{expert_id}.w1.{lut['pre_quant_scale']}"] = w1_pre_quant_scale + weights[ + f"{expert_id}.w2.{lut['pre_quant_scale']}"] = w2_pre_quant_scale + weights[ + f"{expert_id}.w3.{lut['pre_quant_scale']}"] = w3_pre_quant_scale + weights[ + f"{expert_id}.w1.{lut['weight_scale_2']}"] = w1_weight_scale_2 + weights[ + f"{expert_id}.w2.{lut['weight_scale_2']}"] = w2_weight_scale_2 + weights[ + f"{expert_id}.w3.{lut['weight_scale_2']}"] = w3_weight_scale_2 + + quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ) + fused_moe = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=False, + model_config=ModelConfig(quant_config=quant_config), + weight_loading_mode=weight_loading_mode) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + def ref(): + results = torch.zeros_like(x) + selected_experts, final_scales = routing_method.apply(router_logits) + for e_idx in range(NUM_EXPERTS): + mask = selected_experts == e_idx + activated_tokens = mask.sum(1).bool() + act = x[activated_tokens, :] + if act.shape[0] == 0: + continue + final_scale = (final_scales * + mask).sum(1)[activated_tokens].unsqueeze(1) + + # weights + def unpack_weights(weight: torch.Tensor) -> torch.Tensor: + unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + return unpacker(weight.cpu().T.contiguous()).cuda() + else: + return unpacker(weight.cpu()).T.contiguous().cuda() + + w1 = unpack_weights(weights[f"{e_idx}.w1.{lut['weight']}"]) + w2 = unpack_weights(weights[f"{e_idx}.w2.{lut['weight']}"]) + w3 = unpack_weights(weights[f"{e_idx}.w3.{lut['weight']}"]) + w3_w1 = torch.cat([w3, w1], dim=-1) + + # weight_scale + s1 = weights[f"{e_idx}.w1.{lut['weight_scale']}"].T.contiguous( + ).cuda() + s2 = weights[f"{e_idx}.w2.{lut['weight_scale']}"].T.contiguous( + ).cuda() + s3 = weights[f"{e_idx}.w3.{lut['weight_scale']}"].T.contiguous( + ).cuda() + s3_s1 = torch.cat([s3, s1], dim=-1) + + # input_scale + p1 = weights[f"{e_idx}.w1.{lut['input_scale']}"].cuda() + p2 = weights[f"{e_idx}.w2.{lut['input_scale']}"].cuda() + p3 = weights[f"{e_idx}.w3.{lut['input_scale']}"].cuda() + p3_p1 = torch.max(p1, p3) + + # pre_quant_scale + a1 = a2 = a3 = a1_a3 = None + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + a1 = weights[ + f"{e_idx}.w1.{lut['pre_quant_scale']}"].T.contiguous( + ).cuda() + a2 = weights[ + f"{e_idx}.w2.{lut['pre_quant_scale']}"].T.contiguous( + ).cuda() + a3 = weights[ + f"{e_idx}.w3.{lut['pre_quant_scale']}"].T.contiguous( + ).cuda() + a1_a3 = torch.max(a1, a3) + + # weight_scale_2 + q1 = q2 = q3 = q3_q1 = None + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + q1 = weights[f"{e_idx}.w1.{lut['weight_scale_2']}"].cuda() + q2 = weights[f"{e_idx}.w3.{lut['weight_scale_2']}"].cuda() + q3 = weights[f"{e_idx}.w2.{lut['weight_scale_2']}"].cuda() + q3_q1 = torch.max(q3, q1) + + # forward pass + def process_layer( + act, + weight, + weight_scale, + input_scale, + pre_quant_scale=None, + weight_scale_2=None, + ): + if pre_quant_scale is not None: + act = act * pre_quant_scale + act = (torch.clamp((act / input_scale), -448.0, + 448.0).to(torch.float8_e4m3fn).to(dtype)) + weight = (weight.float() * weight_scale.repeat_interleave( + 128, dim=0).float()).to(dtype) + if weight_scale_2 is not None: + weight /= weight_scale_2 + output = torch.matmul(act, weight) * input_scale + if weight_scale_2 is not None: + output *= weight_scale_2 + return output + + # fc13 + fc1 = process_layer( + act, + w3_w1, + s3_s1, + p3_p1, + pre_quant_scale=a1_a3, + weight_scale_2=q3_q1, + ) + fc1, gate = fc1.chunk(2, dim=-1) + fc1 = fc1 * torch.nn.functional.silu(gate) + + # fc2 + fc2 = process_layer(fc1, + w2, + s2, + p2, + pre_quant_scale=a2, + weight_scale_2=q2) + + results[activated_tokens, :] += (fc2 * final_scale).to( + results.dtype) + return results + + AutoTuner.get().clear_cache() + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + + torch.cuda.synchronize() + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref() + + torch.cuda.synchronize() + # assert that result does not contain NaN or is all 0s + assert not torch.isnan(ref_output).any(), "ref_output contains NaN" + assert not torch.isnan(output).any(), "output contains NaN" + assert torch.nonzero(output).numel() > 0, "output is empty" + assert torch.nonzero(ref_output).numel() > 0, "ref_output is empty" + # compare + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + + +@skip_pre_blackwell +@pytest.mark.parametrize("moe_backend", ["TRTLLM", "CUTLASS"]) +@pytest.mark.parametrize("bias", [True, False]) +def test_fused_moe_mxfp4_mxpf8(moe_backend, bias): + SCALING_VECTOR_SIZE = 32 + dtype = torch.bfloat16 + SEQ_LEN = 128 + HIDDEN_SIZE = 256 + INTERMEDIATE_SIZE = 256 + NUM_EXPERTS = 8 + TOP_K = 1 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) torch.manual_seed(0) torch.cuda.manual_seed(0) - x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") - x_sf_global = (448 * 6) / x.abs().max().float() - router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), - dtype=dtype, - device="cuda") + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() * 0.1 + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() weights = {} for expert_id in range(NUM_EXPERTS): - w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), - dtype=dtype, - device="cuda") - w1_sf_global = (448 * 6) / w1_weight.abs().max().float() - - w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), - dtype=dtype, - device="cuda") - w2_sf_global = (448 * 6) / w2_weight.abs().max().float() - + if bias: + w1_bias = torch.randn( + (INTERMEDIATE_SIZE, ), dtype=dtype).cuda() * 0.1 + w2_bias = torch.randn((HIDDEN_SIZE, ), dtype=dtype).cuda() * 0.1 + w3_bias = torch.randn( + (INTERMEDIATE_SIZE, ), dtype=dtype).cuda() * 0.1 + weights[f"{expert_id}.w1.bias"] = w1_bias + weights[f"{expert_id}.w2.bias"] = w2_bias + weights[f"{expert_id}.w3.bias"] = w3_bias + w1_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() * 0.1 + w2_weight = torch.randn( + (HIDDEN_SIZE, INTERMEDIATE_SIZE), dtype=dtype).cuda() * 0.1 w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), - dtype=dtype, - device="cuda") - w3_sf_global = (448 * 6) / w3_weight.abs().max().float() + dtype=dtype).cuda() - w3_w1_global = min( - w1_sf_global, - w3_sf_global) # w3 global and w1 global must be the same - - w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize( - w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False) - w1_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse( + w1_weight_mxfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize( + w1_weight, None, SCALING_VECTOR_SIZE, True) + w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1)) - w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize( - w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False) - w2_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse( + w2_weight_mxfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize( + w2_weight, None, SCALING_VECTOR_SIZE, True) + w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( w2_sf_block.cpu().view(HIDDEN_SIZE, -1)) - w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize( - w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False) - w3_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse( + w3_weight_mxfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize( + w3_weight, None, SCALING_VECTOR_SIZE, True) + w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1)) - w1_input_scale = x_sf_global.cuda() - w2_input_scale = x_sf_global.cuda() - w3_input_scale = x_sf_global.cuda() - - weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4 - weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4 - weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4 + weights[f"{expert_id}.w1.weight"] = w1_weight_mxfp4 + weights[f"{expert_id}.w2.weight"] = w2_weight_mxfp4 + weights[f"{expert_id}.w3.weight"] = w3_weight_mxfp4 weights[f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.view( - torch.float8_e4m3fn).cuda() + torch.uint8).cuda() weights[f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.view( - torch.float8_e4m3fn).cuda() + torch.uint8).cuda() weights[f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.view( - torch.float8_e4m3fn).cuda() - weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale - weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale - weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale - weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global - weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global - weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global + torch.uint8).cuda() - quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) - fused_moe = CutlassFusedMoE( + quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_MXFP4_MXFP8) + fused_moe = create_moe( num_experts=NUM_EXPERTS, routing_method=routing_method, hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, dtype=dtype, - reduce_results=False, - model_config=ModelConfig(quant_config=quant_config)) - fused_moe.load_weights([weights]) + reduce_results=True, + model_config=ModelConfig(quant_config=quant_config, + moe_backend=moe_backend), + bias=bias, + ) fused_moe.cuda() + fused_moe.load_weights([weights]) # Evaluate the outputs on a variant sequence length to cover all possible keys in Autotuner cache ref_fused_moe = RefGatedMLPFusedMoE( @@ -830,9 +1609,10 @@ def test_fused_moe_nvfp4(dtype): hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, dtype=dtype, + bias=bias, model_config=ModelConfig(quant_config=quant_config)) - ref_fused_moe.load_weights([weights]) ref_fused_moe.cuda() + ref_fused_moe.load_weights([weights]) AutoTuner.get().clear_cache() with torch.inference_mode(), autotune(): @@ -844,143 +1624,397 @@ def test_fused_moe_nvfp4(dtype): # compare torch.cuda.synchronize() - torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15) + + +@skip_non_hopper_unittest +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [768, 2880]) +def test_fused_moe_wfp4a16(dtype, hidden_size): + + mapping = Mapping() + mapping.rank = mpi_rank() + + with torch.device(f'cuda:{mapping.rank}'): + SEQ_LEN = 4 + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = 640 + SCALING_GROUP_SIZE = 32 + NUM_EXPERTS = 3 + TOP_K = 2 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() + + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randint(0, + 256, + (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), + dtype=torch.uint8, + device='cuda') + w2_weight = torch.randint(0, + 256, + (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2), + dtype=torch.uint8, + device='cuda') + w3_weight = torch.randint(0, + 256, + (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), + dtype=torch.uint8, + device='cuda') + + w1_scale = torch.randint( + 118, + 123, (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=torch.uint8, + device='cuda') + w2_scale = torch.randint( + 118, + 123, (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), + dtype=torch.uint8, + device='cuda') + w3_scale = torch.randint( + 118, + 123, (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=torch.uint8, + device='cuda') + + weights[f"{expert_id}.w1.weight"] = w1_weight + weights[f"{expert_id}.w2.weight"] = w2_weight + weights[f"{expert_id}.w3.weight"] = w3_weight + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_MXFP4) + fused_moe = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=False, + model_config=ModelConfig(quant_config=quant_config)) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + def ref(): + results = torch.zeros_like(x) + selected_experts, final_scales = routing_method.apply(router_logits) + unpacker = torch.ops.trtllm.mxfp4_dequantize_unswizzled + for e_idx in range(NUM_EXPERTS): + mask = selected_experts == e_idx + activated_tokens = mask.sum(1).bool() + act = x[activated_tokens, :] + if act.shape[0] == 0: + continue + final_scale = (final_scales * + mask).sum(1)[activated_tokens].unsqueeze(1) + + # weights and scales + w1 = weights[f"{e_idx}.w1.weight"] + s1 = weights[f"{e_idx}.w1.weight_scale_inv"] + w2 = weights[f"{e_idx}.w2.weight"] + s2 = weights[f"{e_idx}.w2.weight_scale_inv"] + w3 = weights[f"{e_idx}.w3.weight"] + s3 = weights[f"{e_idx}.w3.weight_scale_inv"] + + # converted weights + w1 = unpacker(w1.cpu(), s1.cpu(), SCALING_GROUP_SIZE).to( + dtype=x.dtype, device=x.device).T.contiguous() + w2 = unpacker(w2.cpu(), s2.cpu(), SCALING_GROUP_SIZE).to( + dtype=x.dtype, device=x.device).T.contiguous() + w3 = unpacker(w3.cpu(), s3.cpu(), SCALING_GROUP_SIZE).to( + dtype=x.dtype, device=x.device).T.contiguous() + w3_w1 = torch.cat([w3, w1], dim=-1) + + fc1 = torch.matmul(act, w3_w1) + fc1, gate = fc1.chunk(2, dim=-1) + fc1 = fc1 * torch.nn.functional.silu(gate) + fc2 = torch.matmul(fc1, w2) + results[activated_tokens, :] += (fc2 * final_scale).to( + results.dtype) + return results + + AutoTuner.get().clear_cache() + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + + torch.cuda.synchronize() + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref() + + # compare + torch.cuda.synchronize() + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + + +@skip_pre_hopper +@pytest.mark.parametrize("experts", [8, 128]) +@pytest.mark.parametrize( + "hidden_size, intermediate_size", + [ + (2880, 2880), + (2880, 1440), + (2880, 720), + (2880, 360), + ], +) +@pytest.mark.parametrize("fp8_activation", [True, False]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("dynamic_quant", [True, False]) +def test_fused_moe_triton_mxfp4(experts, hidden_size, intermediate_size, + fp8_activation, bias, dynamic_quant): + if not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("Triton kernels are not available") + if torch.cuda.get_device_capability()[0] < 10 and fp8_activation: + pytest.skip("Latest Triton requires BF16 activation on Hopper") + if torch.cuda.get_device_capability()[0] >= 10 and not fp8_activation: + pytest.skip("Latest Triton requires FP8 activation on Blackwell") + + mapping = Mapping() + mapping.rank = mpi_rank() + + with torch.device(f'cuda:{mapping.rank}'): + dtype = torch.bfloat16 + SEQ_LEN = 8 + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = intermediate_size + NUM_EXPERTS = experts + TOP_K = 4 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() + + w1_weight = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype).cuda() + w2_weight = torch.randn((NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype).cuda() + w3_weight = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype).cuda() + w1_bias = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE), + dtype=dtype).cuda() + w2_bias = torch.randn((NUM_EXPERTS, HIDDEN_SIZE), dtype=dtype).cuda() + w3_bias = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE), + dtype=dtype).cuda() + + from triton_kernels.numerics_details.mxfp import ( + downcast_to_mxfp_torch, upcast_from_mxfp_torch) + + def fp32_to_mxfp4(tensor): + tensor = tensor.transpose(1, 2).contiguous() + tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor, + torch.uint8, + axis=1) + tensor_fp4 = tensor_fp4.transpose(1, 2).contiguous() + tensor_scales = tensor_scales.transpose(1, 2).contiguous() + return tensor_fp4, tensor_scales + + def mxfp4_to_fp32(tensor, scales): + tensor = tensor.transpose(1, 2).contiguous() + scales = scales.transpose(1, 2).contiguous() + tensor = upcast_from_mxfp_torch(tensor, + scales, + torch.float32, + axis=1) + return tensor.transpose(1, 2).contiguous() + + w1_weight_fp4, w1_weight_scale = fp32_to_mxfp4(w1_weight) + w2_weight_fp4, w2_weight_scale = fp32_to_mxfp4(w2_weight) + w3_weight_fp4, w3_weight_scale = fp32_to_mxfp4(w3_weight) + w1_weight_qdq = mxfp4_to_fp32(w1_weight_fp4, w1_weight_scale) + w2_weight_qdq = mxfp4_to_fp32(w2_weight_fp4, w2_weight_scale) + w3_weight_qdq = mxfp4_to_fp32(w3_weight_fp4, w3_weight_scale) + + # Since we don't have mxfp4 reference, we run the ref in bf16 after q-dq + weights = {} + for expert_id in range(NUM_EXPERTS): + weights[f"{expert_id}.w1.weight"] = w1_weight_qdq[expert_id] + weights[f"{expert_id}.w2.weight"] = w2_weight_qdq[expert_id] + weights[f"{expert_id}.w3.weight"] = w3_weight_qdq[expert_id] + if bias: + weights[f"{expert_id}.w1.bias"] = w1_bias[expert_id] + weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id] + weights[f"{expert_id}.w3.bias"] = w3_bias[expert_id] + + ref_fused_moe = RefGatedMLPFusedMoE(num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(), + bias=bias) + ref_fused_moe.load_weights([weights]) + ref_fused_moe.cuda() + + with torch.inference_mode(): + ref_output = ref_fused_moe.forward(x, router_logits) + torch.cuda.synchronize() + + # Now we run the TritonFusedMoE with MXFP4 weights + weights = {} + + for expert_id in range(NUM_EXPERTS): + if dynamic_quant: + weights[f"{expert_id}.w1.weight"] = w1_weight_qdq[expert_id] + weights[f"{expert_id}.w2.weight"] = w2_weight_qdq[expert_id] + weights[f"{expert_id}.w3.weight"] = w3_weight_qdq[expert_id] + else: + weights[f"{expert_id}.w1.weight"] = w1_weight_fp4[expert_id] + weights[f"{expert_id}.w2.weight"] = w2_weight_fp4[expert_id] + weights[f"{expert_id}.w3.weight"] = w3_weight_fp4[expert_id] + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale[ + expert_id] + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale[ + expert_id] + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale[ + expert_id] + if bias: + weights[f"{expert_id}.w1.bias"] = w1_bias[expert_id] + weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id] + weights[f"{expert_id}.w3.bias"] = w3_bias[expert_id] + + quant_algo = QuantAlgo.W4A8_MXFP4_FP8 if fp8_activation else QuantAlgo.W4A16_MXFP4 + quant_config = QuantConfig(quant_algo=quant_algo) + fused_moe = TritonFusedMoE(num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + bias=bias, + model_config=ModelConfig( + quant_config=quant_config, + mapping=mapping)) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + torch.cuda.synchronize() + + # Evaluate outputs + + # There can be one off mismatch in the outputs due to different kernel implementations + # Here we check certain percent of the outputs are within the tolerance + check_accuracy(output, ref_output, rtol=0.6, atol=0.6, percent=0.945) -@skip_neither_ada_nor_hopper_unittest @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe_w4afp8(dtype): +@pytest.mark.parametrize("weight_dtype", [torch.int8]) +def test_fused_moe_int8_woq_per_channel(dtype, weight_dtype): - SEQ_LEN = 4 - HIDDEN_SIZE = 768 - INTERMEDIATE_SIZE = 640 - SCALING_GROUP_SIZE = 128 - NUM_EXPERTS = 3 - TOP_K = 2 - routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) - torch.manual_seed(0) - torch.cuda.manual_seed(0) - x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") - router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), - dtype=dtype, - device="cuda") + mapping = Mapping() + mapping.rank = mpi_rank() - affine_coeff = 0.005 + with torch.device(f'cuda:{mapping.rank}'): + SEQ_LEN = 4 + HIDDEN_SIZE = 768 + INTERMEDIATE_SIZE = 640 + NUM_EXPERTS = 3 + TOP_K = 2 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") - weights = {} - for expert_id in range(NUM_EXPERTS): - w1_weight = torch.randint(-128, - 127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), - dtype=torch.int8).cuda() - w2_weight = torch.randint(-128, - 127, (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2), - dtype=torch.int8).cuda() - w3_weight = torch.randint(-128, - 127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), - dtype=torch.int8).cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") - w1_scale = torch.randn( - (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + weight_id = 1 # 1 for w8a16, 2 for w4a16 + quant_config = QuantConfig(quant_algo=QuantAlgo.W8A16) + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randint( + -128, + 127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // weight_id), + dtype=torch.int8).cuda() + w2_weight = torch.randint( + -128, + 127, (HIDDEN_SIZE, INTERMEDIATE_SIZE // weight_id), + dtype=torch.int8).cuda() + w3_weight = torch.randint( + -128, + 127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // weight_id), + dtype=torch.int8).cuda() + + w1_scale = torch.randn( + (INTERMEDIATE_SIZE), dtype=dtype, device="cuda") / HIDDEN_SIZE + w2_scale = torch.randn( + (HIDDEN_SIZE), dtype=dtype, device="cuda") / INTERMEDIATE_SIZE + w3_scale = torch.randn( + (INTERMEDIATE_SIZE), dtype=dtype, device="cuda") / HIDDEN_SIZE + + weights[f"{expert_id}.w1.weight"] = w1_weight + weights[f"{expert_id}.w2.weight"] = w2_weight + weights[f"{expert_id}.w3.weight"] = w3_weight + weights[f"{expert_id}.w1.weight_scale"] = w1_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_scale + + fused_moe = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, dtype=dtype, - device="cuda") * affine_coeff - w2_scale = torch.randn( - (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), - dtype=dtype, - device="cuda") * affine_coeff - w3_scale = torch.randn( - (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), - dtype=dtype, - device="cuda") * affine_coeff + reduce_results=False, + model_config=ModelConfig(quant_config=quant_config)) + fused_moe.load_weights([weights]) + fused_moe.cuda() - w1_input = torch.randn(1, dtype=torch.float32, device="cuda") * 0.02 - w2_input = w1_input - w3_input = w1_input + def ref(): + results = torch.zeros_like(x) + selected_experts, final_scales = routing_method.apply(router_logits) + for e_idx in range(NUM_EXPERTS): + mask = selected_experts == e_idx + activated_tokens = mask.sum(1).bool() + act = x[activated_tokens, :] + if act.shape[0] == 0: + continue + final_scale = (final_scales * + mask).sum(1)[activated_tokens].unsqueeze(1) + # weights + w1 = weights[f"{e_idx}.w1.weight"].T.contiguous().cuda() + w2 = weights[f"{e_idx}.w2.weight"].T.contiguous().cuda() + w3 = weights[f"{e_idx}.w3.weight"].T.contiguous().cuda() + w3_w1 = torch.cat([w3, w1], dim=-1) + # scales + s1 = weights[f"{e_idx}.w1.weight_scale"].cuda() + s2 = weights[f"{e_idx}.w2.weight_scale"].cuda() + s3 = weights[f"{e_idx}.w3.weight_scale"].cuda() + s3_s1 = torch.cat([s3, s1], dim=-1) + # calculation + w3_w1 = (w3_w1.float() * s3_s1).to(dtype) + fc1 = torch.matmul(act, w3_w1) + fc1, gate = fc1.chunk(2, dim=-1) + act = fc1 * torch.nn.functional.silu(gate) + w2 = (w2.float() * s2).to(dtype) + fc2 = torch.matmul(act, w2) + results[activated_tokens, :] += (fc2 * final_scale).to( + results.dtype) + return results - weights[f"{expert_id}.w1.weight"] = w1_weight - weights[f"{expert_id}.w2.weight"] = w2_weight - weights[f"{expert_id}.w3.weight"] = w3_weight - weights[f"{expert_id}.w1.weight_scale_inv"] = w1_scale - weights[f"{expert_id}.w2.weight_scale_inv"] = w2_scale - weights[f"{expert_id}.w3.weight_scale_inv"] = w3_scale - weights[f"{expert_id}.w1.input_scale"] = w1_input - weights[f"{expert_id}.w2.input_scale"] = w2_input - weights[f"{expert_id}.w3.input_scale"] = w3_input + AutoTuner.get().clear_cache() + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) - quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ) - fused_moe = CutlassFusedMoE( - num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, - reduce_results=False, - model_config=ModelConfig(quant_config=quant_config)) - fused_moe.load_weights([weights]) - fused_moe.cuda() + torch.cuda.synchronize() + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref() - def ref(): - results = torch.zeros_like(x) - selected_experts, final_scales = routing_method.apply(router_logits) - unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 - for e_idx in range(NUM_EXPERTS): - mask = selected_experts == e_idx - activated_tokens = mask.sum(1).bool() - act = x[activated_tokens, :] - if act.shape[0] == 0: - continue - final_scale = (final_scales * - mask).sum(1)[activated_tokens].unsqueeze(1) - - # weights - w1 = weights[f"{e_idx}.w1.weight"] - w1 = unpacker(w1.cpu()).T.contiguous().cuda() - w2 = weights[f"{e_idx}.w2.weight"] - w2 = unpacker(w2.cpu()).T.contiguous().cuda() - w3 = weights[f"{e_idx}.w3.weight"] - w3 = unpacker(w3.cpu()).T.contiguous().cuda() - w3_w1 = torch.cat([w3, w1], dim=-1) - - # scales - s1 = weights[f"{e_idx}.w1.weight_scale_inv"].T.contiguous().cuda() - s2 = weights[f"{e_idx}.w2.weight_scale_inv"].T.contiguous().cuda() - s3 = weights[f"{e_idx}.w3.weight_scale_inv"].T.contiguous().cuda() - s3_s1 = torch.cat([s3, s1], dim=-1) - - # prequant / alpha - p1 = weights[f"{e_idx}.w1.input_scale"].cuda() - p2 = weights[f"{e_idx}.w2.input_scale"].cuda() - p3 = weights[f"{e_idx}.w3.input_scale"].cuda() - p3_p1 = max(p1, p3) - - act = torch.clamp((act / p3_p1), -448.0, - 448.0).to(torch.float8_e4m3fn).to(dtype) - w3_w1 = (w3_w1.float() * - s3_s1.repeat_interleave(128, dim=0).float()).to(dtype) - fc1 = torch.matmul(act, w3_w1) * p3_p1 - fc1, gate = fc1.chunk(2, dim=-1) - fc1 = fc1 * torch.nn.functional.silu(gate) - - act = torch.clamp((fc1 / p2), -448.0, - 448.0).to(torch.float8_e4m3fn).to(dtype) - w2 = (w2.float() * - s2.repeat_interleave(128, dim=0).float()).to(dtype) - fc2 = torch.matmul(act, w2) * p2 - results[activated_tokens, :] += (fc2 * final_scale).to( - results.dtype) - return results - - AutoTuner.get().clear_cache() - with torch.inference_mode(), autotune(): - fused_moe.forward(x, router_logits) - - torch.cuda.synchronize() - with torch.inference_mode(): - output = fused_moe.forward(x, router_logits) - ref_output = ref() - - # compare - torch.cuda.synchronize() - torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + # compare + torch.cuda.synchronize() + atol = calc_woq_tolerence(ref_output, weight_dtype) + torch.testing.assert_close(output, ref_output, rtol=1e-7, atol=atol) class RefGatedMLPFusedMoE(nn.Module): @@ -991,12 +2025,15 @@ class RefGatedMLPFusedMoE(nn.Module): hidden_size: int, intermediate_size: int, dtype: Optional[torch.dtype] = None, - model_config: ModelConfig = ModelConfig()): + model_config: ModelConfig = ModelConfig(), + use_cute_dsl_blockscaling_mm: bool = False, + bias=False): super().__init__() self.num_experts = num_experts self.routing_method = routing_method self.hidden_size = hidden_size self.intermediate_size = intermediate_size + self.bias = bias self.dtype = dtype self.quant_config = model_config.quant_config @@ -1005,9 +2042,10 @@ class RefGatedMLPFusedMoE(nn.Module): GatedMLP( hidden_size=self.hidden_size, intermediate_size=self.intermediate_size, - bias=False, + bias=bias, dtype=self.dtype, config=model_config, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm, ) for _ in range(self.num_experts) ]) @@ -1047,6 +2085,10 @@ class RefGatedMLPFusedMoE(nn.Module): gate_up_proj_weights[0]['weight'] = weights[f"{expert}.w1.weight"] gate_up_proj_weights[1]['weight'] = weights[f"{expert}.w3.weight"] down_proj_weights[0]['weight'] = weights[f"{expert}.w2.weight"] + if self.bias: + gate_up_proj_weights[0]['bias'] = weights[f"{expert}.w1.bias"] + gate_up_proj_weights[1]['bias'] = weights[f"{expert}.w3.bias"] + down_proj_weights[0]['bias'] = weights[f"{expert}.w2.bias"] if self.quant_config and self.quant_config.quant_algo == QuantAlgo.FP8: gate_up_proj_weights[0]['weight_scale'] = weights[ @@ -1088,6 +2130,13 @@ class RefGatedMLPFusedMoE(nn.Module): f"{expert}.w3.weight_scale"] down_proj_weights[0]["weight_scale"] = weights[ f"{expert}.w2.weight_scale"] + elif self.quant_config and self.quant_config.quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: + gate_up_proj_weights[0]['weight_scale'] = weights[ + f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]['weight_scale'] = weights[ + f"{expert}.w3.weight_scale"] + down_proj_weights[0]['weight_scale'] = weights[ + f"{expert}.w2.weight_scale"] self.experts[expert].gate_up_proj.load_weights(gate_up_proj_weights) self.experts[expert].down_proj.load_weights(down_proj_weights) diff --git a/tests/unittest/_torch/modules/test_moe_load_balancer.py b/tests/unittest/_torch/modules/test_moe_load_balancer.py index caad476484..969fb04d03 100644 --- a/tests/unittest/_torch/modules/test_moe_load_balancer.py +++ b/tests/unittest/_torch/modules/test_moe_load_balancer.py @@ -222,22 +222,18 @@ class TestMoeLoadBalancer(unittest.TestCase): # wait_for_gpu_stage mock_wait.return_value = torch.tensor([1]) - layer.wait_for_gpu_stage() + layer.start_wait_gpu_stage() + layer.done_wait_gpu_stage() result = layer.statistic_flag_tensor mock_wait.assert_called_once_with( mock_single_layer_impl.get_pointer()) self.assertEqual(result, mock_wait.return_value) - # set_cpu_stage - layer.set_cpu_stage() - mock_set_cpu.assert_called_once_with( - mock_single_layer_impl.get_pointer()) - # statistic mock_expert_ids = torch.tensor([[0, 1], [2, 3]]) mock_enabled = torch.tensor([1]) layer.statistic_flag_tensor = mock_enabled - layer.statistic(mock_expert_ids, True, False) + layer.update_statistic_with_global_ids(mock_expert_ids, True, False) mock_statistic.assert_called_once_with( mock_expert_ids, mock_enabled, mock_single_layer_impl.get_pointer(), True, False) @@ -248,6 +244,12 @@ class TestMoeLoadBalancer(unittest.TestCase): result = layer.route(mock_selected_experts) assert torch.equal(result, mock_route.return_value) + # set_cpu_stage + layer.start_set_cpu_stage() + layer.done_set_cpu_stage() + mock_set_cpu.assert_called_once_with( + mock_single_layer_impl.get_pointer()) + @patch('tensorrt_llm.bindings.internal.runtime.MoeLoadBalancer') def test_moe_load_balancer_lifecycle_methods(self, mock_load_balancer_impl): """Test lifecycle methods of MoeLoadBalancer.""" @@ -267,7 +269,7 @@ class TestMoeLoadBalancer(unittest.TestCase): mock_load_balancer_impl.return_value.set_warm_up_iter_count.assert_called_once_with( 10) - balancer.set_next_iter_info(True, True) + balancer.set_iter_info(True, True) with MoeLoadBalancerIterContext(balancer): mock_load_balancer_impl.return_value.start_iter.assert_called_once_with( @@ -306,7 +308,7 @@ class TestMoeLoadBalancer(unittest.TestCase): balancer.finalize_model() # enable statistic, disable weight update - balancer.set_next_iter_info(True, False) + balancer.set_iter_info(True, False) # Create sample token data - each token selects 2 experts # 4 tokens, each selecting 2 experts @@ -323,13 +325,16 @@ class TestMoeLoadBalancer(unittest.TestCase): try: with MoeLoadBalancerIterContext(balancer): # Wait for GPU stage and get enabled flag - layer.wait_for_gpu_stage() + layer.start_wait_gpu_stage() + layer.done_wait_gpu_stage() # Run statistic - just test it runs without error - layer.statistic(gathered_raw_expert_ids, True, True) + layer.update_statistic_with_global_ids(gathered_raw_expert_ids, + True, True) # Set CPU stage to signal completion - layer.set_cpu_stage() + layer.start_set_cpu_stage() + layer.done_set_cpu_stage() # Test passed if we got here without exceptions self.assertTrue(True, "Statistic kernel ran successfully") @@ -368,7 +373,7 @@ class TestMoeLoadBalancer(unittest.TestCase): balancer.finalize_model() # enable statistic, disable weight update - balancer.set_next_iter_info(True, False) + balancer.set_iter_info(True, False) # Create sample token data - tokens selecting different experts token_selected_experts = torch.tensor( @@ -384,13 +389,15 @@ class TestMoeLoadBalancer(unittest.TestCase): try: with MoeLoadBalancerIterContext(balancer): # Wait for GPU stage - layer.wait_for_gpu_stage() + layer.start_wait_gpu_stage() + layer.done_wait_gpu_stage() # Run routing routed_slots = layer.route(token_selected_experts) # Set CPU stage - layer.set_cpu_stage() + layer.start_set_cpu_stage() + layer.done_set_cpu_stage() # Verify results - with our initial assignment, expert i should map to slot i expected_slots = torch.tensor( diff --git a/tests/unittest/_torch/modules/test_moe_routing.py b/tests/unittest/_torch/modules/test_moe_routing.py index 53a6e0992b..405ef0299f 100644 --- a/tests/unittest/_torch/modules/test_moe_routing.py +++ b/tests/unittest/_torch/modules/test_moe_routing.py @@ -2,11 +2,10 @@ import pytest import torch import torch.nn.functional as F -from tensorrt_llm._torch.modules.fused_moe import (DefaultMoeRoutingMethod, - LoadBalancedMoeRoutingMethod, - RenormalizeMoeRoutingMethod, - SparseMixerMoeRoutingMethod, - StaticMoeRoutingMethod) +from tensorrt_llm._torch.modules.fused_moe import ( + DefaultMoeRoutingMethod, LoadBalancedMoeRoutingMethod, + RenormalizeMoeRoutingMethod, SparseMixerMoeRoutingMethod, + StaticMoeRoutingMethod, create_renormalize_expert_load_balanced_logits) # Test DefaultMoeRoutingMethod with different top_k values @@ -169,34 +168,96 @@ def test_load_balanced_moe_routing(): def test_static_moe_routing(): routing = StaticMoeRoutingMethod( torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda()) - assert routing.experts_per_token == 4 + with torch.device('cpu'): + assert routing.experts_per_token == 4 - logits = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], - dtype=torch.float32).cuda() + logits = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], + dtype=torch.float32).cuda() + indices, scales = routing.apply(logits) + indices = indices.cpu() + + assert scales is None + assert indices.shape == (2, 4) + assert indices.dtype == torch.int32 + + assert torch.equal( + indices, + torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32)) + + routing = StaticMoeRoutingMethod( + torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], + dtype=torch.int32).cuda(), + torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], + dtype=torch.float32).cuda()) + indices, scales = routing.apply(logits) + scales = scales.cpu() + + assert scales is not None + assert scales.shape == (2, 4) + assert scales.dtype == torch.float32 + assert torch.equal( + scales, + torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], + dtype=torch.float32)) + + +@pytest.mark.parametrize( + "num_tokens,expected_assignments,description", + [(3, [2, 2, 1, 1], + "3 tokens - slight imbalance due to total work not divisible by EP size"), + (4, [2, 2, 2, 2], "4 tokens - perfect balance"), + (32, [16, 16, 16, 16], "32 tokens - large batch with perfect balance")]) +def test_renormalize_expert_load_balanced_logits(num_tokens, + expected_assignments, + description): + """Test GPU load balancing with RenormalizeMoeRoutingMethod across different token counts.""" + # Test parameters (consistent across all test cases) + num_experts = 8 + experts_per_token = 2 + moe_ep_size = 4 + device = torch.device('cuda') + + # Generate expert load balanced logits using the utility function directly + logits = create_renormalize_expert_load_balanced_logits( + num_tokens=num_tokens, + num_experts=num_experts, + experts_per_token=experts_per_token, + moe_ep_size=moe_ep_size, + device=device, + dtype=torch.float32) + + # Use RenormalizeMoeRoutingMethod to get expert assignments + routing = RenormalizeMoeRoutingMethod(top_k=experts_per_token) indices, scales = routing.apply(logits) - indices = indices.cpu() - assert scales is None - assert indices.shape == (2, 4) - assert indices.dtype == torch.int32 + # Verify shapes + assert indices.shape == (num_tokens, experts_per_token) + assert scales.shape == (num_tokens, experts_per_token) - assert torch.equal( - indices, torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32)) + # Count expert assignments per GPU + # GPU 0: experts [0, 1], GPU 1: experts [2, 3], GPU 2: experts [4, 5], GPU 3: experts [6, 7] + gpu_assignments = [0, 0, 0, 0] # Count for each GPU + experts_per_gpu = num_experts // moe_ep_size # = 2 - routing = StaticMoeRoutingMethod( - torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda(), - torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], - dtype=torch.float32).cuda()) - indices, scales = routing.apply(logits) - scales = scales.cpu() + indices_flat = indices.view(-1).cpu() + for expert_idx in indices_flat: + gpu_id = expert_idx.item() // experts_per_gpu + gpu_assignments[gpu_id] += 1 - assert scales is not None - assert scales.shape == (2, 4) - assert scales.dtype == torch.float32 - assert torch.equal( - scales, - torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], - dtype=torch.float32)) + # Verify total assignments and expected distribution + total_expected = num_tokens * experts_per_token + assert sum( + gpu_assignments + ) == total_expected, f"Total assignments mismatch for {description}" + + # For cases where perfect balance isn't possible, check sorted equality + # For perfect balance cases, check exact equality + if len(set(expected_assignments) + ) == 1: # All values are the same (perfect balance) + assert gpu_assignments == expected_assignments, f"Perfect balance expected for {description}" + else: # Slight imbalance expected + assert sorted(gpu_assignments) == sorted( + expected_assignments), f"Load balance failed for {description}" if __name__ == '__main__': diff --git a/tests/unittest/_torch/modules/test_triton_linear.py b/tests/unittest/_torch/modules/test_triton_linear.py new file mode 100644 index 0000000000..2d5e87ae68 --- /dev/null +++ b/tests/unittest/_torch/modules/test_triton_linear.py @@ -0,0 +1,192 @@ +import pickle +import sys + +import cloudpickle +import pytest +import torch +from mpi4py import MPI +from utils.util import check_accuracy, skip_pre_hopper + +from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \ + IS_TRITON_KERNELS_AVAILABLE +from tensorrt_llm._torch.modules.linear import Linear +from tensorrt_llm._torch.modules.triton_linear import TritonLinear +from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig + +cloudpickle.register_pickle_by_value(sys.modules[__name__]) +MPI.pickle.__init__( + cloudpickle.dumps, + cloudpickle.loads, + pickle.HIGHEST_PROTOCOL, +) + + +@pytest.mark.parametrize("linear_cls", [Linear, TritonLinear]) +def test_linear_unquantized(linear_cls): + if not IS_TRITON_KERNELS_AVAILABLE and linear_cls is TritonLinear: + pytest.skip("Triton kernels are not available") + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + num_tokens = 128 + hidden_size = 64 + out_size = 256 + dtype = torch.bfloat16 + x = torch.randn((num_tokens, hidden_size), dtype=dtype).cuda() + w = torch.randn((hidden_size, out_size), dtype=dtype).cuda() + b = torch.randn((out_size, ), dtype=dtype).cuda() + + weights = { + "weight": w.T, # Transpose to match TRT-LLM's weight shape + "bias": b, + } + + linear = linear_cls( + in_features=hidden_size, + out_features=out_size, + bias=True, + dtype=dtype, + ) + linear.load_weights([weights]) + linear.cuda() + + actual_c = linear.forward(x) + reference_c = torch.matmul(x, w) + b + + check_accuracy(actual_c, reference_c, atol=0.01, rtol=0.01, percent=0.99) + + +@pytest.mark.parametrize("linear_cls", [Linear, TritonLinear]) +def test_linear_fp8qdq(linear_cls): + if not IS_TRITON_KERNELS_AVAILABLE and linear_cls is TritonLinear: + pytest.skip("Triton kernels are not available") + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + num_tokens = 128 + hidden_size = 64 + out_size = 256 + dtype = torch.bfloat16 + x = torch.randn((num_tokens, hidden_size), dtype=dtype).cuda() + qx, sx = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(x) + w = torch.randn((hidden_size, out_size), dtype=dtype).cuda() + qw, sw = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(w) + b = torch.randn((out_size, ), dtype=dtype).cuda() + + weights = { + "weight": qw.T, # Transpose to match TRT-LLM's weight shape + "bias": b, + "input_scale": sx, + "weight_scale": sw, + } + + linear = linear_cls(in_features=hidden_size, + out_features=out_size, + bias=True, + dtype=dtype, + quant_config=QuantConfig(quant_algo=QuantAlgo.FP8)) + linear.load_weights([weights]) + linear.cuda() + + actual_c = linear.forward(qx) + x_qdq = qx.to(torch.float32) * sx + w_qdq = qw.to(torch.float32) * sw + reference_c = torch.matmul(x_qdq, w_qdq) + b + + check_accuracy(actual_c, + reference_c.to(dtype), + atol=0.01, + rtol=0.01, + percent=0.99) + + +@skip_pre_hopper +@pytest.mark.parametrize("activation_dtype", + [torch.bfloat16, torch.float8_e4m3fn]) +def test_linear_mxfp4(activation_dtype): + if not IS_TRITON_KERNELS_AVAILABLE: + pytest.skip("Triton kernels are not available") + if torch.cuda.get_device_capability( + )[0] < 10 and activation_dtype == torch.float8_e4m3fn: + pytest.skip("Latest Triton requires BF16 activation on Hopper") + if torch.cuda.get_device_capability( + )[0] >= 10 and activation_dtype == torch.bfloat16: + pytest.skip("Latest Triton requires FP8 activation on Blackwell") + + dtype = torch.bfloat16 + num_tokens = 128 + hidden_size = 256 # Must be even and divisible by 32 for MXFP4 + out_size = 512 + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((num_tokens, hidden_size), dtype=dtype).cuda() + w = torch.randn((hidden_size, out_size), dtype=dtype).cuda() + b = torch.randn((out_size, ), dtype=dtype).cuda() + + from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp_torch, + upcast_from_mxfp_torch) + + def fp32_to_mxfp4(tensor): + # tensor (in_features, out_features) + tensor = tensor.unsqueeze(0) + tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor, + torch.uint8, + axis=1) + return tensor_fp4[0], tensor_scales[0] + + def mxfp4_to_fp32(tensor, scales): + tensor = tensor.unsqueeze(0) + scales = scales.unsqueeze(0) + tensor = upcast_from_mxfp_torch(tensor, scales, torch.float32, axis=1) + return tensor[0] + + # Convert weight to MXFP4 + w_weight_fp4, w_weight_scale = fp32_to_mxfp4(w) + w_weight_qdq = mxfp4_to_fp32(w_weight_fp4, w_weight_scale) + + # Create reference linear with dequantized weights + ref_weights = { + "weight": w_weight_qdq.T, # Transpose to match TRT-LLM's weight shape + "bias": b, + } + + ref_linear = Linear( # Always use regular Linear for reference + in_features=hidden_size, + out_features=out_size, + bias=True, + dtype=dtype, + ) + ref_linear.load_weights([ref_weights]) + ref_linear.cuda() + + ref_output = ref_linear.forward(x) + torch.cuda.synchronize() + + # Now test with MXFP4 quantized weights + weights = { + "weight": w_weight_fp4.T, # Transpose to match TRT-LLM's weight shape + "bias": b, + "weight_scale": + w_weight_scale.T, # Transpose scale to match weight shape + } + + # Add input scale for FP8 activation + if activation_dtype == torch.float8_e4m3fn: + _, input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(x) + weights["input_scale"] = input_scale + + quant_algo = QuantAlgo.W4A8_MXFP4_FP8 if activation_dtype == torch.float8_e4m3fn else QuantAlgo.W4A16_MXFP4 + + linear = TritonLinear(in_features=hidden_size, + out_features=out_size, + bias=True, + dtype=dtype, + quant_config=QuantConfig(quant_algo=quant_algo)) + linear.load_weights([weights]) + linear.cuda() + + output = linear.forward(x) + torch.cuda.synchronize() + + # Compare outputs with more relaxed tolerance for MXFP4 + check_accuracy(output, ref_output, rtol=0.2, atol=0.2, percent=0.95) diff --git a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py index e3d00f4683..7ccbc50d7b 100644 --- a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py @@ -21,9 +21,9 @@ import pytest import torch from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor -from utils.util import skip_pre_blackwell import tensorrt_llm +import tensorrt_llm.bindings.internal.userbuffers as ub from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, AllReduceParams) from tensorrt_llm.functional import AllReduceStrategy @@ -55,6 +55,7 @@ def run_single_rank( dtype, fused_add_norm, reference_output_list, + strategy, ): rank = tensorrt_llm.mpi_rank() torch.cuda.set_device(rank) @@ -70,6 +71,7 @@ def run_single_rank( rank, fused_add_norm, reference_output_list, + strategy, ) except Exception: traceback.print_exc() @@ -89,6 +91,7 @@ def row_linear_residual_norm_fusion_forward( tensor_parallel_rank: int, fusion: bool, reference_output_list: list[tuple[torch.Tensor, ...]], + strategy: AllReduceStrategy, ): # Move all tensors to GPU @@ -100,6 +103,12 @@ def row_linear_residual_norm_fusion_forward( for ref_output in reference_output_list ] + if strategy == AllReduceStrategy.NCCL_SYMMETRIC: + ub.initialize_userbuffers_manager( + tensor_parallel_size, 1, 1, tensor_parallel_rank, + torch.cuda.device_count(), + x_list[0].nelement() * x_list[0].element_size(), True) + MPI.COMM_WORLD.barrier() # Create a single AllReduce instance to be reused for all sequence lengths @@ -109,7 +118,7 @@ def row_linear_residual_norm_fusion_forward( tp_size=tensor_parallel_size, rank=tensor_parallel_rank, ), - strategy=AllReduceStrategy.MNNVL, + strategy=strategy, dtype=dtype, ) @@ -152,31 +161,28 @@ def row_linear_residual_norm_fusion_forward( ) -@skip_pre_blackwell @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test") @pytest.mark.parametrize( "seq_len", - [ - [1], - [4], - [15], - [32], - [128], - [31, 11, 27, 4], - ], + [[1], [4], [15], [32], [128], [31, 11, 27, 4], [998] + ], # Test for max_num_token fallback ids=lambda x: f"seqlen:{x}", ) -@pytest.mark.parametrize("hidden_size", [7168], ids=lambda x: f"hidden:{x}") -@pytest.mark.parametrize("dtype", - [torch.float16, torch.bfloat16, torch.float32], +@pytest.mark.parametrize("hidden_size", [2880, 7168], + ids=lambda x: f"hidden:{x}") +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=lambda x: f"dtype:{torch.finfo(x).dtype}") +@pytest.mark.parametrize( + "strategy", [AllReduceStrategy.MNNVL, AllReduceStrategy.NCCL_SYMMETRIC], + ids=lambda x: f"strategy:{x}") @pytest.mark.parametrize( "fusion", [True, False], ids=["fusion", "no_fusion"], ) -def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, fusion): +def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, strategy, + fusion): torch.manual_seed(42) tensor_parallel_size = 2 @@ -222,6 +228,7 @@ def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, fusion): dtype, fusion, reference_output_list, + strategy, ) for i in range(tensor_parallel_size) ]), ) diff --git a/tests/unittest/_torch/multi_gpu/test_star_attention.py b/tests/unittest/_torch/multi_gpu/test_star_attention.py index 89f8521b12..abad54e6bc 100644 --- a/tests/unittest/_torch/multi_gpu/test_star_attention.py +++ b/tests/unittest/_torch/multi_gpu/test_star_attention.py @@ -8,6 +8,7 @@ from utils.llm_data import llm_models_root from tensorrt_llm import LLM, SamplingParams from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.llmapi.utils import get_total_gpu_memory +from tensorrt_llm.mapping import CpType from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig MAX_SEQ_LEN = 4096 + 1024 @@ -54,7 +55,7 @@ def test_model(backend, model_name, quant, sp_size, sa_block_size, model_dir = str(llm_models_root() / model_name) cp_config = { - "cp_type": "star_attention", + "cp_type": CpType.STAR, "cp_anchor_size": sa_anchor_size, "block_size": sa_block_size } diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py index 601f5acfbc..340b2ea628 100644 --- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py +++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py @@ -35,7 +35,8 @@ pytestmark = pytest.mark.threadleak(enabled=False) def init_userbuffers_allocator(tp_size, rank, max_ub_size): ub.initialize_userbuffers_manager(tp_size, 1, 1, rank, - torch.cuda.device_count(), max_ub_size) + torch.cuda.device_count(), max_ub_size, + False) def create_userbuffers_tensor(shape, dtype): @@ -977,7 +978,7 @@ def run_single_rank_ub_pass_fp4( def block_scale_unswizzled(scale): sz = fp4_utils.pad_up(hidden_size, 128) - return torch.ops.trtllm.nvfp4_block_scale_interleave_reverse( + return torch.ops.trtllm.block_scale_interleave_reverse( scale.cpu().view(sz, -1))[0:hidden_size] l0_weight_scale_block_unswizzled = block_scale_unswizzled( diff --git a/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py b/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py index 658ec64fb5..5c374d0f2a 100644 --- a/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py +++ b/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py @@ -1,7 +1,6 @@ from difflib import SequenceMatcher import pytest -import torch from utils.llm_data import llm_models_root from tensorrt_llm import LLM, SamplingParams @@ -44,17 +43,19 @@ def test_llama4(model_name, backend, tp_size, use_cuda_graph, "This is a very long prompt to exercise long context. Count up to 10000 from 1, 2, 3," + ", ".join(str(i) for i in range(4, 9000)) }, - { - "prompt": "<|image|>This image is of color", - "multi_modal_data": { - "image": [torch.ones(3, 1024, 1024)] - } - }, + # TODO: Fix multimodal test. + # { + # "prompt": "<|image|>This image is of color", + # "multi_modal_data": { + # "image": [torch.ones(3, 1024, 1024)] + # } + # }, ] expected_outputs = [ - " the head of state and head of government of the", ", 8999, 9000, ", - " white. What is the color of the background of" + " the head of state and head of government of the", + ", 9000, 9001, ", + # " white. What is the color of the background of" # TODO: Fix multimodal test. ] pytorch_config = dict(attn_backend=backend) diff --git a/tests/unittest/_torch/multimodal/test_share_multiparams.py b/tests/unittest/_torch/multimodal/test_share_multiparams.py index d4ce40f633..343c2f5372 100644 --- a/tests/unittest/_torch/multimodal/test_share_multiparams.py +++ b/tests/unittest/_torch/multimodal/test_share_multiparams.py @@ -39,14 +39,20 @@ class TestMultimodalParamsHandleConversion(unittest.TestCase): params.to_handle("multimodal_data") self.assertEqual(params.multimodal_data, {}) + def test_to_handle_unsupported_element(self): + """Test to_handle raises ValueError for unsupported elements.""" params = MultimodalParams() multimodal_input = MultimodalInput( multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]] * 2, multimodal_positions=[0, 10], multimodal_lengths=[2, 2]) params.multimodal_input = multimodal_input - params.to_handle("multimodal_input") - self.assertEqual(params.multimodal_input, multimodal_input) + + with self.assertRaises(ValueError) as context: + params.to_handle("multimodal_input") + + self.assertIn("Unsupported element 'multimodal_input'", + str(context.exception)) def test_to_tensor_basic_handle(self): """Test converting a basic handle back to tensor.""" @@ -54,9 +60,9 @@ class TestMultimodalParamsHandleConversion(unittest.TestCase): params.multimodal_data = {"multimodal_embedding": self.mm_embedding} # Convert to handle - params.to_handle("multimodal_data", key="multimodal_embedding") + params.to_handle("multimodal_data") # Convert back to tensor - params.to_tensor("multimodal_data", key="multimodal_embedding") + params.to_tensor("multimodal_data") result = params.multimodal_data["multimodal_embedding"] self.assertIsInstance(result, torch.Tensor) @@ -67,8 +73,8 @@ class TestMultimodalParamsHandleConversion(unittest.TestCase): params = MultimodalParams() params.multimodal_data = self.sample_multimodal_data.copy() - params.to_handle("multimodal_data", key=None) - params.to_tensor("multimodal_data", key=None) + params.to_handle("multimodal_data") + params.to_tensor("multimodal_data") self.assertTrue( torch.allclose(params.multimodal_data["multimodal_embedding"], @@ -90,5 +96,56 @@ class TestMultimodalParamsHandleConversion(unittest.TestCase): self.image["image_width"]) +class TestMultimodalParamsDeviceTransfer(unittest.TestCase): + """Test cases for to_device method in MultimodalParams.""" + + def setUp(self): + """Set up test fixtures.""" + self.mm_embedding = torch.randn(3, 4, 5) + self.mrope_config = { + "mrope_rotary_cos_sin": torch.randn(2, 3), + "mrope_position_deltas": torch.randn(5), + } + self.image = { + "pixel_values": torch.randn(1, 3, 224, 224), + "image_height": [224], + "image_width": [224], + } + self.sample_multimodal_data = { + "multimodal_embedding": self.mm_embedding, + "mrope_config": self.mrope_config, + "image": self.image, + } + + def test_to_device_basic(self): + """Test converting a basic data to device.""" + params = MultimodalParams() + params.multimodal_data = {"multimodal_embedding": self.mm_embedding} + + params.to_device("multimodal_data", device="cuda:0", pin_memory=True) + + result = params.multimodal_data["multimodal_embedding"] + self.assertEqual(result.device, torch.device("cuda:0")) + + def test_to_device_all_data(self): + """Test converting all data to device.""" + params = MultimodalParams() + params.multimodal_data = self.sample_multimodal_data.copy() + + params.to_device("multimodal_data", device="cuda:0", pin_memory=True) + + result = params.multimodal_data["multimodal_embedding"] + self.assertEqual(result.device, torch.device("cuda:0")) + + result = params.multimodal_data["mrope_config"]["mrope_rotary_cos_sin"] + self.assertEqual(result.device, torch.device("cuda:0")) + + result = params.multimodal_data["mrope_config"]["mrope_position_deltas"] + self.assertEqual(result.device, torch.device("cuda:0")) + + result = params.multimodal_data["image"]["pixel_values"] + self.assertEqual(result.device, torch.device("cuda:0")) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unittest/_torch/speculative/test_ngram.py b/tests/unittest/_torch/speculative/test_ngram.py index 90bd647156..7ea9a41bac 100644 --- a/tests/unittest/_torch/speculative/test_ngram.py +++ b/tests/unittest/_torch/speculative/test_ngram.py @@ -54,7 +54,7 @@ def test_llama_ngram(disable_overlap_scheduler: bool, use_cuda_graph: bool, "The capital of France is", "The president of the United States is", ] - sampling_params = SamplingParams(max_tokens=32) + sampling_params = SamplingParams(max_tokens=32, ignore_eos=True) llm_spec = LLM(**llm_common_config, speculative_config=spec_config) results_spec = llm_spec.generate(prompts, sampling_params) diff --git a/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py b/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py new file mode 100644 index 0000000000..13960e528b --- /dev/null +++ b/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py @@ -0,0 +1,53 @@ +import unittest + +import numpy as np +import torch +from scipy.stats import entropy + +from tensorrt_llm._torch.pyexecutor.sampler import (get_rejected_indices, + sample_rejected) + + +def test_get_rejected_indices(): + vocab_size = 500 + num_iter = 50000 + draft_probs = torch.rand(1, vocab_size) + drop_idx = torch.topk(draft_probs[0], k=400, largest=False)[1] + draft_probs[0, drop_idx] = 0.0 + draft_probs = draft_probs / draft_probs.sum(dim=-1, keepdim=True) + target_probs = torch.rand(2, vocab_size) + drop_idx = torch.topk(target_probs[0], k=400, largest=False)[1] + target_probs[0, drop_idx] = 0.0 + target_probs = target_probs / target_probs.sum(dim=-1, keepdim=True) + generator = torch.Generator() + sampled_tokens = [] + sampled_regular = [] + for _ in range(num_iter): + draft_tokens = [ + torch.multinomial(draft_probs, num_samples=1, + generator=generator).item() + ] + rejected_indices = get_rejected_indices(draft_probs, target_probs, + generator, draft_tokens) + if rejected_indices.shape[0] == 0: + sampled_tokens.append(draft_tokens[0]) + else: + sampled_tokens.append( + sample_rejected(draft_probs, target_probs, generator, 0).item()) + sampled_regular.append( + torch.multinomial(target_probs[0], + num_samples=1, + generator=generator).item()) + bins = np.arange(vocab_size + 1) - 0.5 # Bins for histogram + sampled_tokens, _ = np.histogram(sampled_tokens, bins=bins, density=True) + sampled_regular, _ = np.histogram(sampled_regular, bins=bins, density=True) + expected_prob = target_probs[0].squeeze().numpy() + + # KL Divergence check + kl_divergence = entropy(expected_prob, sampled_tokens) + kl_divergence_regular = entropy(expected_prob, sampled_regular) + assert abs(kl_divergence - kl_divergence_regular) < 0.01 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/_torch/test_attention_mla.py b/tests/unittest/_torch/test_attention_mla.py index a61975ddf8..41c37031bf 100644 --- a/tests/unittest/_torch/test_attention_mla.py +++ b/tests/unittest/_torch/test_attention_mla.py @@ -388,6 +388,12 @@ def test_attention_mla(scenario: Scenario, context_sequence_lengths: List[int], device = torch.device('cuda') dtype = scenario.dtype kv_cache_dtype = scenario.kv_cache_dtype + + FAILED_CSL = [777, 912, 431, 42, 266, 989, 524] + if (kv_cache_dtype is torch.float8_e4m3fn + and context_sequence_lengths == FAILED_CSL): + pytest.skip("https://nvbugs/5453806") + print( f"--------------------------------Test for scenario: {scenario} start--------------------------------" ) diff --git a/tests/unittest/_torch/test_autotuner.py b/tests/unittest/_torch/test_autotuner.py index 21eb0a9626..c2f5c32141 100644 --- a/tests/unittest/_torch/test_autotuner.py +++ b/tests/unittest/_torch/test_autotuner.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List import torch @@ -19,14 +19,16 @@ def test_multi_dynamic_dims(): x = torch.rand([5, 1024]) w = torch.rand([7, 19]) dynamic_tensor_specs = ( - DynamicTensorSpec(0, 0, [1, 3, 5], lambda x: x // 2), - DynamicTensorSpec(0, 1, [16, 24, 1024], lambda x: x // 2), + DynamicTensorSpec(0, 0, [1, 3, 5]), + DynamicTensorSpec(0, 1, [16, 24, 1024]), DynamicTensorSpec(1, 1, [3, 7, 9], lambda x: x // 2), ) profiles = tuner._optimization_profiles( tuning_config=TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs), inputs=[x, w]) + # choice(0, 0) * choice(0, 1) * choice(1, 1) + # 3 * 3 * 3 = 27, because 19 is mapped to 9 and already inside the bucket assert len(profiles) == 27 sample_0 = OptimizationProfile(shapes=[[ DynamicDim(min=1, opt=1, max=3), @@ -90,7 +92,7 @@ def check_gemm_tactic_valid(tactic: int, m: int) -> bool: class GemmRunner(TunableRunner): def get_valid_tactics(self, inputs: List[FakeTensor], - profile: OptimizationProfile) -> List[int]: + profile: OptimizationProfile, **kwargs) -> List[int]: # The simulated delay is not deterministic, so we need to return specific tactics here return [-1, 0, 1] @@ -98,7 +100,8 @@ class GemmRunner(TunableRunner): /, inputs: List[torch.Tensor], *, - tactic: int = -1) -> torch.Tensor: + tactic: int = -1, + **kwargs) -> torch.Tensor: assert tactic in [-1, 0, 1] return [gemm_0, gemm_1, gemm_fallback][tactic](*inputs) @@ -258,14 +261,18 @@ def test_multiple_runners_different_attributes(): # Verify different cache keys are generated shapes = (x.shape, w.shape) - cache_key_0 = tuner._get_cache_key(custom_op="test_multiple_runners", - input_shapes=shapes, - runner=runner_0, - tuning_config=tuning_config) - cache_key_1 = tuner._get_cache_key(custom_op="test_multiple_runners", - input_shapes=shapes, - runner=runner_1, - tuning_config=tuning_config) + cache_key_0 = tuner._get_cache_key( + custom_op="test_multiple_runners", + input_shapes=shapes, + runner=runner_0, + tuning_config=tuning_config, + ) + cache_key_1 = tuner._get_cache_key( + custom_op="test_multiple_runners", + input_shapes=shapes, + runner=runner_1, + tuning_config=tuning_config, + ) assert cache_key_0 != cache_key_1, "Runners with different attributes should have different cache keys" @@ -301,3 +308,47 @@ def test_multiple_dynamic_shapes_cache(): ] assert len(cache_entries) == 12, \ f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}" + + +class GemmRunnerWithTacticConfigs(TunableRunner): + valid_tactic_ids = [-1, 0, 1] + + def get_valid_tactics( + self, + inputs: List[FakeTensor], + profile: OptimizationProfile, + ) -> List[Dict[str, int]]: + # The simulated delay is not deterministic, so we need to return specific tactics here + return [{ + "block_size": block_size, + "tactic_id": tactic_id + } for tactic_id in self.valid_tactic_ids for block_size in [128, 256]] + + def forward( + self, + /, + inputs: List[torch.Tensor], + *, + tactic: dict = {}, + ) -> torch.Tensor: + # Notice that in fallback case tactic is -1 + if tactic == -1: + # assign default configs for fallback case + block_size, tactic_id = 128, -1 + else: + block_size, tactic_id = tactic["block_size"], tactic["tactic_id"] + assert tactic_id in self.valid_tactic_ids + return [gemm_0, gemm_1, gemm_fallback][tactic_id](*inputs) + + +def test_autotuner_tactic_configs(): + runner_0 = GemmRunnerWithTacticConfigs() + runners = [runner_0] + x, w = torch.randn(64, 64), torch.randn(64, 128) + tuning_config = TuningConfig() + with autotune(): + tuner = AutoTuner.get() + runner, tactic = tuner.choose_one("test_autotuner_tactic_configs", + runners, tuning_config, [x, w]) + + runner_0.forward(inputs=[x, w], tactic=tactic) diff --git a/tests/unittest/_torch/test_beam_search.py b/tests/unittest/_torch/test_beam_search.py index 1b417ef284..1eb8cf350c 100644 --- a/tests/unittest/_torch/test_beam_search.py +++ b/tests/unittest/_torch/test_beam_search.py @@ -43,7 +43,6 @@ def llm(fixed_params, input_prompts): input_prompts ), # use small batch size to prevent large buffers from possibly hiding wrong data accesses. max_seq_len=32, - enable_trtllm_sampler=True, max_beam_width=fixed_params["max_beam_width"], disable_overlap_scheduler=True, cuda_graph_config=None, @@ -60,10 +59,10 @@ def llm_cuda_graph(fixed_params, input_prompts): input_prompts ), # use small batch size to prevent large buffers from possibly hiding wrong data accesses. max_seq_len=32, - enable_trtllm_sampler=True, max_beam_width=fixed_params["max_beam_width"], disable_overlap_scheduler=False, - cuda_graph_config=CudaGraphConfig(), + cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8], + enable_padding=True), ) @@ -128,7 +127,7 @@ def test_beam_search_output_shapes(gather_context_logits: bool, @pytest.mark.parametrize("gather_generation_logits", [True, False]) @pytest.mark.parametrize("gather_context_logits", [True, False]) @pytest.mark.parametrize("num_output_beams", [1, 2]) -@pytest.mark.parametrize("num_prompts", [1, 2]) +@pytest.mark.parametrize("num_prompts", [1, 2, 3]) @pytest.mark.threadleak(enabled=False) def test_beam_search_output_shapes_cuda_graph_and_overlap( gather_context_logits: bool, gather_generation_logits: bool, @@ -147,6 +146,10 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap( return_generation_logits=gather_generation_logits, logprobs=return_log_probs, ) + # test padding of cuda graph with 3 prompts + # replicate the prompts to have more than 2 prompts available + if (num_prompts == 3 and len(input_prompts) == 2): + input_prompts = [input_prompts[0]] * 3 outputs = llm_cuda_graph.generate(input_prompts[:num_prompts], sampling_params=sampling_params) assert len(outputs) == num_prompts diff --git a/tests/unittest/_torch/test_best_of_n.py b/tests/unittest/_torch/test_best_of_n.py index 89653269d6..90890efdd9 100644 --- a/tests/unittest/_torch/test_best_of_n.py +++ b/tests/unittest/_torch/test_best_of_n.py @@ -39,7 +39,6 @@ def llm(): kv_cache_config=KvCacheConfig(max_tokens=1000), max_batch_size=8, max_seq_len=64, - enable_trtllm_sampler=True, disable_overlap_scheduler=True) diff --git a/tests/unittest/_torch/test_custom_ops.py b/tests/unittest/_torch/test_custom_ops.py index 9a50468ae3..18c0933062 100644 --- a/tests/unittest/_torch/test_custom_ops.py +++ b/tests/unittest/_torch/test_custom_ops.py @@ -91,6 +91,10 @@ def test_register_fake(custom_ops): "trtllm::set_chunked_kv_cache_for_mla", "trtllm::mla_rope_append_paged_kv_assign_q", "trtllm::fused_qk_norm_rope", + "trtllm::bf16_mxe2m1_block_scale_moe_runner", + "trtllm::e4m3_mxe2m1_block_scale_moe_runner", + "trtllm::mxe4m3_mxe2m1_block_scale_moe_runner", + "trtllm::mxfp8_quantize", } ops_missing_fake_impl = [] diff --git a/tests/unittest/_torch/test_flashinfer_star_attn.py b/tests/unittest/_torch/test_flashinfer_star_attn.py index 7bad00724c..ef19d2e3cd 100644 --- a/tests/unittest/_torch/test_flashinfer_star_attn.py +++ b/tests/unittest/_torch/test_flashinfer_star_attn.py @@ -13,7 +13,7 @@ from tensorrt_llm._torch.attention_backend import (StarAttention, from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig -from tensorrt_llm.mapping import Mapping +from tensorrt_llm.mapping import CpType, Mapping class TestingStarAttentionMetadata(StarAttentionMetadata): @@ -144,7 +144,7 @@ class TestStarAttention(unittest.TestCase): tokens_per_block = 64 max_seq_len = tokens_per_block * num_blocks cp_config = { - "cp_type": "star_attention", + "cp_type": CpType.STAR, "cp_anchor_size": scenario.anchor_size, "block_size": scenario.block_size } @@ -579,7 +579,7 @@ class TestStarAttention(unittest.TestCase): max_seq_len = tokens_per_block * num_blocks num_layers = 1 if isinstance(num_kv_heads, int) else len(num_kv_heads) cp_config = { - "cp_type": "star_attention", + "cp_type": CpType.STAR, "cp_anchor_size": test_scenario.anchor_size, "block_size": test_scenario.block_size } diff --git a/tests/unittest/_torch/test_overlap_scheduler.py b/tests/unittest/_torch/test_overlap_scheduler.py index 7321503a58..8d7406aacc 100644 --- a/tests/unittest/_torch/test_overlap_scheduler.py +++ b/tests/unittest/_torch/test_overlap_scheduler.py @@ -21,10 +21,10 @@ def model_path(): return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" -def create_llm(model_dir, disable_overlap_scheduler, enable_trtllm_sampler): +def create_llm(model_dir, disable_overlap_scheduler, sampler_type): """Create LLM with specific overlap scheduler setting""" pytorch_config = dict(disable_overlap_scheduler=disable_overlap_scheduler, - enable_trtllm_sampler=enable_trtllm_sampler) + sampler_type=sampler_type) trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False) @@ -41,16 +41,15 @@ def create_llm(model_dir, disable_overlap_scheduler, enable_trtllm_sampler): ) -@pytest.mark.parametrize("enable_trtllm_sampler", [False, True]) +@pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"]) @pytest.mark.high_cuda_memory -def test_overlap_scheduler_consistency(model_path, test_case, - enable_trtllm_sampler): +def test_overlap_scheduler_consistency(model_path, test_case, sampler_type): # Test configuration prompts = test_case["prompts"] max_new_tokens = test_case["max_new_tokens"] temperature = test_case["temperature"] top_p = test_case["top_p"] - stop_words = test_case["stop_words"] if enable_trtllm_sampler else None + stop_words = test_case["stop_words"] sampling_config = SamplingParams(max_tokens=max_new_tokens, stop=stop_words, @@ -62,7 +61,7 @@ def test_overlap_scheduler_consistency(model_path, test_case, # Test with overlap scheduler enabled llm = create_llm(model_path, disable_overlap_scheduler=False, - enable_trtllm_sampler=enable_trtllm_sampler) + sampler_type=sampler_type) outputs_with_overlap = llm.generate(prompts, sampling_params=sampling_config, use_tqdm=True) @@ -74,7 +73,7 @@ def test_overlap_scheduler_consistency(model_path, test_case, # Test with overlap scheduler disabled llm = create_llm(model_path, disable_overlap_scheduler=True, - enable_trtllm_sampler=enable_trtllm_sampler) + sampler_type=sampler_type) outputs_without_overlap = llm.generate(prompts, sampling_params=sampling_config, use_tqdm=True) diff --git a/tests/unittest/_torch/test_resource_manager.py b/tests/unittest/_torch/test_resource_manager.py index da1dae84ba..24320a993b 100644 --- a/tests/unittest/_torch/test_resource_manager.py +++ b/tests/unittest/_torch/test_resource_manager.py @@ -5,11 +5,11 @@ import sys import unittest import numpy as np -import pytest import torch import tensorrt_llm import tensorrt_llm.bindings +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager, PeftCacheConfig, PeftCacheManager) @@ -17,6 +17,7 @@ from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp from tensorrt_llm.bindings import executor as tllm from tensorrt_llm.bindings.internal.batch_manager import \ PeftTaskNotCachedException +from tensorrt_llm.lora_helper import LoraConfig DataType = tensorrt_llm.bindings.DataType LoraModule = tensorrt_llm.bindings.LoraModule @@ -247,7 +248,7 @@ class TestResourceManager(unittest.TestCase): lora_config = torch.from_numpy(lora_config) input_tokens = [i + 1 for i in range(max_new_tokens)] - request = tensorrt_llm.bindings.internal.batch_manager.LlmRequest( + request = LlmRequest( request_id=request_id, max_new_tokens=max_new_tokens, input_tokens=input_tokens, @@ -261,15 +262,13 @@ class TestResourceManager(unittest.TestCase): return request def get_lora_data(self): - """Create mock LoRA weights and config that match the C++ validation expectations. + """Create mock LoRA weights and config. Returns: - tuple: (weights tensor, config tensor) formatted correctly for the C++ implementation. + tuple: (weights tensor, config tensor). """ lora_weights = np.load(self.TP1_WEIGHTS_PATH).astype(np.float16) - lora_weights = np.expand_dims(lora_weights, axis=0) lora_config = np.load(self.TP1_CONFIG_PATH) - lora_config = np.expand_dims(lora_config, axis=0) return lora_weights, lora_config def test_successful_mocked_peft_cache_manager_initialization(self): @@ -277,6 +276,7 @@ class TestResourceManager(unittest.TestCase): peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) @@ -290,6 +290,7 @@ class TestResourceManager(unittest.TestCase): peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) @@ -307,6 +308,7 @@ class TestResourceManager(unittest.TestCase): peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) @@ -322,6 +324,7 @@ class TestResourceManager(unittest.TestCase): peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) @@ -349,13 +352,13 @@ class TestResourceManager(unittest.TestCase): self.assertEqual(len(peft_table), self.num_lora_modules) - @pytest.mark.skip(reason="https://nvbugs/5324252") def test_put_get(self): """Test adding a request with properly configured LoRA weights and config.""" peft_cache_config = self.create_peft_cache_config() peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) diff --git a/tests/unittest/_torch/test_return_logits.py b/tests/unittest/_torch/test_return_logits.py index a9e0b1a430..0d6a5e28ca 100644 --- a/tests/unittest/_torch/test_return_logits.py +++ b/tests/unittest/_torch/test_return_logits.py @@ -16,10 +16,10 @@ global_kvcache_config = KvCacheConfig(max_tokens=10000) @pytest.mark.parametrize("return_log_probs", [False, True]) @pytest.mark.parametrize("gather_generation_logits", [False, True]) @pytest.mark.parametrize("gather_context_logits", [False, True]) -@pytest.mark.parametrize("enable_trtllm_sampler", [False, True]) +@pytest.mark.parametrize("sampler_type", ["TRTLLMSampler", "TorchSampler"]) @pytest.mark.parametrize("disable_overlap_scheduler", [False, True]) def test_generate_with_return_logits(disable_overlap_scheduler: bool, - enable_trtllm_sampler: bool, + sampler_type: str, gather_context_logits: bool, gather_generation_logits: bool, return_log_probs: bool): @@ -27,7 +27,7 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool, or return_log_probs): # prune space pytest.skip("Nothing to test") - if not enable_trtllm_sampler and gather_context_logits: + if sampler_type == "TorchSampler" and gather_context_logits: pytest.skip("TorchSampler does not support gather_context_logits") build_config = BuildConfig() @@ -41,7 +41,7 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool, gather_generation_logits=gather_generation_logits, max_batch_size= 128, # reduce buffer sizes, specially for generation logits - enable_trtllm_sampler=enable_trtllm_sampler, + sampler_type=sampler_type, disable_overlap_scheduler=disable_overlap_scheduler, ) @@ -83,10 +83,10 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool, @pytest.mark.parametrize("return_log_probs", [False, True]) @pytest.mark.parametrize("gather_generation_logits", [False, True]) @pytest.mark.parametrize("gather_context_logits", [False, True]) -@pytest.mark.parametrize("enable_trtllm_sampler", [False, True]) +@pytest.mark.parametrize("sampler_type", ["TRTLLMSampler", "TorchSampler"]) @pytest.mark.parametrize("disable_overlap_scheduler", [False, True]) def test_generate_async_with_return_logits(disable_overlap_scheduler: bool, - enable_trtllm_sampler: bool, + sampler_type: str, gather_context_logits: bool, gather_generation_logits: bool, return_log_probs: bool): @@ -94,7 +94,7 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool, or return_log_probs): # prune space pytest.skip("Nothing to test") - if not enable_trtllm_sampler and gather_context_logits: + if sampler_type == "TorchSampler" and gather_context_logits: pytest.skip("TorchSampler does not support gather_context_logits") build_config = BuildConfig() @@ -108,7 +108,7 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool, gather_generation_logits=gather_generation_logits, max_batch_size= 128, # reduce buffer sizes, specially for generation logits - enable_trtllm_sampler=enable_trtllm_sampler, + sampler_type=sampler_type, disable_overlap_scheduler=disable_overlap_scheduler, ) sampling_params = SamplingParams( diff --git a/tests/unittest/_torch/test_trtllm_sampler.py b/tests/unittest/_torch/test_trtllm_sampler.py index 2f3c31bbb8..d2fb0e9e65 100644 --- a/tests/unittest/_torch/test_trtllm_sampler.py +++ b/tests/unittest/_torch/test_trtllm_sampler.py @@ -24,8 +24,6 @@ def model_path(): def create_llm(model_dir): """Create LLM with specific overlap scheduler setting""" - pytorch_config = dict(enable_trtllm_sampler=True) - trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False) return LLM( @@ -34,7 +32,6 @@ def create_llm(model_dir): trust_remote_code=True, enable_chunked_prefill=True, cuda_graph_config=CudaGraphConfig(), - **pytorch_config, kv_cache_config=trt_kv_cache_config, max_num_tokens= 128 # Only one request longer than max_num_tokens is required to test chunked prefill diff --git a/tests/unittest/_torch/thop/test_fp4_bmm_quantize.py b/tests/unittest/_torch/thop/test_fp4_bmm_quantize.py index c05914127c..38edd1f155 100644 --- a/tests/unittest/_torch/thop/test_fp4_bmm_quantize.py +++ b/tests/unittest/_torch/thop/test_fp4_bmm_quantize.py @@ -172,8 +172,8 @@ def test_fp4_sf_interleave(b, m, k): w_cuda = w.cuda() # The cpu and cuda kernels are different - w_out_cpu = torch.ops.trtllm.nvfp4_block_scale_interleave(w) - w_out_cuda = torch.ops.trtllm.nvfp4_block_scale_interleave(w_cuda) + w_out_cpu = torch.ops.trtllm.block_scale_interleave(w) + w_out_cuda = torch.ops.trtllm.block_scale_interleave(w_cuda) torch.cuda.synchronize() torch.testing.assert_allclose(w_out_cpu.cuda(), w_out_cuda) diff --git a/tests/unittest/_torch/thop/test_fp4_gemm_quantize.py b/tests/unittest/_torch/thop/test_fp4_gemm_quantize.py index 75af5d9420..f1faf28109 100644 --- a/tests/unittest/_torch/thop/test_fp4_gemm_quantize.py +++ b/tests/unittest/_torch/thop/test_fp4_gemm_quantize.py @@ -127,27 +127,40 @@ class TestFunctional(unittest.TestCase): c_pt = torch.nn.functional.linear(a_pt, b_pt) self.assertTrue(torch.allclose(c_pt, c, atol=1e-2, rtol=1e-2)) - @parameterized.expand(list([[1024, 1024, torch.half, False], - [2, 512, torch.bfloat16, False], - [13, 16, torch.half, True]]), + @parameterized.expand(list([[1024, 1024, torch.half, False, True], + [2, 512, torch.bfloat16, False, True], + [2, 512, torch.bfloat16, True, True], + [16, 512, torch.half, True, True], + [16, 512, torch.half, False, True], + [16, 512, torch.half, True, False], + [16, 512, torch.half, False, False]]), name_func=unittest_name_func) @skip_pre_blackwell_unittest - def test_fp4_quantize_torch(self, m, k, dtype, use_ue8m0): + def test_fp4_quantize_torch(self, m, k, dtype, use_ue8m0, + is_sf_swizzled_layout): a = torch.randn([m, k], dtype=torch.float32).to(dtype).float() + if use_ue8m0: + # Expand the range of the input to cover more cases + a = a * 16 + a_global_sf = (448 * 6) / a.abs().max().float() - sf_vec_size = 16 + sf_vec_size = 32 if use_ue8m0 else 16 a_fp4, a_sf = torch.ops.trtllm.fp4_quantize( - a.to(dtype).cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0) + a.to(dtype).cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, + is_sf_swizzled_layout) + + sf_type = 0 if use_ue8m0 else 1 a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(), 1 / a_global_sf, - sf_vec_size) + sf_vec_size, sf_type, + is_sf_swizzled_layout) torch.cuda.synchronize() - if not use_ue8m0: - # The gap is too large for ue8m0, so we just make sure that it runs - self.assertTrue(torch.allclose(a_pt, a, atol=1, rtol=0)) + atol = 8 if use_ue8m0 else 1 + rtol = 0 + self.assertTrue(torch.allclose(a_pt, a, atol=atol, rtol=rtol)) @parameterized.expand(list([[2, 16, torch.half, False, True], [2, 16, torch.half, False, False], @@ -157,11 +170,12 @@ class TestFunctional(unittest.TestCase): [1024, 512, torch.bfloat16, True, False]]), name_func=unittest_name_func) @skip_pre_blackwell_unittest - def test_fp4_quantize_torch_different_sf_layot(self, m, k, dtype, use_ue8m0, - is_sf_swizzled_layout): + def test_fp4_quantize_torch_different_sf_layout(self, m, k, dtype, + use_ue8m0, + is_sf_swizzled_layout): a = torch.randn([m, k], dtype=torch.float32).to(dtype).float() a_global_sf = (448 * 6) / a.abs().max().float() - sf_vec_size = 16 + sf_vec_size = 32 if use_ue8m0 else 16 a_fp4, a_sf = torch.ops.trtllm.fp4_quantize( a.to(dtype).cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, @@ -178,7 +192,7 @@ class TestFunctional(unittest.TestCase): self.assertTrue(torch.allclose(a_pt, a, atol=1, rtol=0)) @parameterized.expand(list([[64, 64, torch.float8_e4m3fn, False, True], - [13, 16, torch.float8_e4m3fn, True, True], + [13, 32, torch.float8_e4m3fn, True, True], [3, 48, torch.float8_e4m3fn, False, False], [1024, 1024, torch.float8_e4m3fn, True, False]]), @@ -192,7 +206,7 @@ class TestFunctional(unittest.TestCase): a_fp8 = (a / amax * 448).to(dtype) aq_fp32 = a_fp8.float() * amax / 448 a_global_sf = (448 * 6) / amax - sf_vec_size = 16 + sf_vec_size = 32 if use_ue8m0 else 16 a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a_fp8.cuda(), a_global_sf.cuda(), diff --git a/tests/unittest/_torch/thop/test_fp4_linear.py b/tests/unittest/_torch/thop/test_fp4_linear.py index 19ac7f3ed0..bce3709191 100644 --- a/tests/unittest/_torch/thop/test_fp4_linear.py +++ b/tests/unittest/_torch/thop/test_fp4_linear.py @@ -38,9 +38,8 @@ def test_fp4_linear(dtype): assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2 assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype - w_sf_block_unswizzled = ( - torch.ops.trtllm.nvfp4_block_scale_interleave_reverse( - w_sf_block.cpu().view(HIDDEN_SIZE, -1))) + w_sf_block_unswizzled = (torch.ops.trtllm.block_scale_interleave_reverse( + w_sf_block.cpu().view(HIDDEN_SIZE, -1))) l_fp4.load_weights([{ 'input_scale': diff --git a/tests/unittest/_torch/thop/test_fp8_quantize.py b/tests/unittest/_torch/thop/test_fp8_quantize.py index 72608b554c..54b94e6c40 100644 --- a/tests/unittest/_torch/thop/test_fp8_quantize.py +++ b/tests/unittest/_torch/thop/test_fp8_quantize.py @@ -16,7 +16,9 @@ import math import pytest import torch -from utils.util import getSMVersion +from parameterized import parameterized +from utils.util import (getSMVersion, skip_pre_blackwell_unittest, + unittest_name_func) def _dequant_fp8(input, scale, transpose_scale, block_m, block_n): @@ -95,3 +97,102 @@ def test_fp8_quantize_blackwell(dtype, m, k): a.cpu().to(torch.float32), atol=1e-1, rtol=1e-1) + + +def mxfp8_quantize_check_accuracy(a, b, atol, rtol, percent): + if torch.any(torch.isnan(a)): + raise Exception("NaN in a") + if torch.any(torch.isnan(b)): + raise Exception("NaN in b") + assert a.shape == b.shape + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception("Mismatch percentage is %f for rtol %f" % + (mismatch_percent, rtol)) + + +@parameterized.expand(list([[1, 1024, torch.half, True], + [2, 512, torch.bfloat16, True], + [16, 512, torch.half, True], + [16, 512, torch.half, False], + [1024, 512, torch.half, False], + [1024, 512, torch.half, False]]), + name_func=unittest_name_func) +@skip_pre_blackwell_unittest +def test_mxfp8_quantize_torch_host(m, k, dtype, is_sf_swizzled_layout): + torch.random.manual_seed(0) + a = (torch.randn([m, k], dtype=torch.float) * 16).cpu().contiguous() + + a_fp8, a_sf = torch.ops.tensorrt_llm.quantize_mxe4m3_host( + a, is_sf_swizzled_layout) + + a_pt = torch.ops.tensorrt_llm.dequantize_mxe4m3_host( + a_fp8.view(torch.uint8), a_sf.view(torch.uint8), is_sf_swizzled_layout) + + torch.cuda.synchronize() + + mxfp8_quantize_check_accuracy(a_pt, a, 8, 0, 0.999) + + +@parameterized.expand(list([[1, 1024, torch.half, True], + [2, 512, torch.bfloat16, True], + [16, 512, torch.half, True], + [16, 512, torch.half, False], + [1024, 512, torch.half, False], + [1024, 512, torch.half, False]]), + name_func=unittest_name_func) +@skip_pre_blackwell_unittest +def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout): + torch.random.manual_seed(0) + a = (torch.randn([m, k], dtype=torch.float) * + 16).to(dtype).cuda().contiguous() + + # Quantize it on device. + a_fp8, a_sf = torch.ops.trtllm.mxfp8_quantize(a, is_sf_swizzled_layout, 32) + + # Dequantize it on host. + a_pt = torch.ops.tensorrt_llm.dequantize_mxe4m3_host( + a_fp8.cpu().view(torch.uint8), + a_sf.cpu().view(torch.uint8), is_sf_swizzled_layout) + + torch.cuda.synchronize() + + mxfp8_quantize_check_accuracy(a_pt.cpu().to(torch.float32), + a.cpu().to(torch.float32), 8, 0, 0.999) + + +@pytest.mark.parametrize("m", [1, 2, 16, 1024]) +@pytest.mark.parametrize("k", [1568]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) +@pytest.mark.parametrize("alignment", [64, 128]) +@skip_pre_blackwell_unittest +def test_mxfp8_quantize_alignment_torch_device(m, k, dtype, + is_sf_swizzled_layout, + alignment): + torch.random.manual_seed(0) + a = (torch.randn([m, k], dtype=torch.float) * + 16).to(dtype).cuda().contiguous() + padded_k = ((k + alignment - 1) // alignment) * alignment + + # Quantize it on device. + a_fp8, a_sf = torch.ops.trtllm.mxfp8_quantize(a, is_sf_swizzled_layout, + alignment) + assert a_fp8.shape[1] == padded_k + + # Dequantize it on host. + a_pt = torch.ops.tensorrt_llm.dequantize_mxe4m3_host( + a_fp8.cpu().view(torch.uint8), + a_sf.cpu().view(torch.uint8), is_sf_swizzled_layout) + + # Check if the bits of paddings are zero. + paddings = a_fp8.view(torch.int8)[:, k:] + assert torch.all(paddings == 0), "Paddings should be zero" + + torch.cuda.synchronize() + + mxfp8_quantize_check_accuracy(a_pt[:, :k].cpu().to(torch.float32), + a.cpu().to(torch.float32), 8, 0, 0.999) diff --git a/tests/unittest/_torch/thop/test_fused_qk_norm_rope.py b/tests/unittest/_torch/thop/test_fused_qk_norm_rope.py index ad76e9705e..fd9a924915 100644 --- a/tests/unittest/_torch/thop/test_fused_qk_norm_rope.py +++ b/tests/unittest/_torch/thop/test_fused_qk_norm_rope.py @@ -142,10 +142,12 @@ def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox, eps = 1e-5 base = 10000.0 + factor, low, high, attention_factor = 1.0, 0, 0, 1.0 # Run the custom fusedQKNormRope operation torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v, head_dim, eps, q_weight, - k_weight, base, is_neox, position_ids) + k_weight, base, is_neox, position_ids, + factor, low, high, attention_factor) output = qkv # This op is inplace # Compute reference output using TensorRT-LLM modules diff --git a/tests/unittest/_torch/thop/test_moe.py b/tests/unittest/_torch/thop/test_moe.py index f9dc149c17..e8f1b68903 100644 --- a/tests/unittest/_torch/thop/test_moe.py +++ b/tests/unittest/_torch/thop/test_moe.py @@ -22,6 +22,8 @@ import torch import torch.nn.functional as F sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from enum import Enum + from utils.util import getSMVersion from tensorrt_llm._torch.autotuner import autotune @@ -31,14 +33,38 @@ from tensorrt_llm.quantization.utils.fp4_utils import ( reorder_rows_for_gated_act_gemm, shuffle_matrix_a, shuffle_matrix_sf_a) +# Keep this in sync with the ActType in cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h +class ActType(Enum): + SwiGlu = 0 + + class moe_args: - def __init__(self, num_tokens, num_experts, hidden_size, intermediate_size, - top_k, padding, hidden_states, hidden_states_scale, - hidden_states_scale_global, expert_logits, gemm1_weights, - gemm1_scales, gemm1_scales_global, gemm2_weights, gemm2_scales, - gemm2_scales_global, permute_info, - use_routing_scales_on_input): + def __init__(self, + num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states, + hidden_states_scale, + hidden_states_scale_global, + expert_logits, + gemm1_weights, + gemm1_scales, + gemm1_scales_global, + gemm2_weights, + gemm2_scales, + gemm2_scales_global, + permute_info, + use_routing_scales_on_input, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_bias=None, + act_type=ActType.SwiGlu): self.num_tokens = num_tokens self.num_experts = num_experts self.hidden_size = hidden_size @@ -57,13 +83,35 @@ class moe_args: self.gemm2_scales_global = gemm2_scales_global self.permute_info = permute_info self.use_routing_scales_on_input = use_routing_scales_on_input + self.gemm1_bias = gemm1_bias + self.gemm1_alpha = gemm1_alpha + self.gemm1_beta = gemm1_beta + self.gemm1_clamp_limit = gemm1_clamp_limit + self.gemm2_bias = gemm2_bias + self.act_type = act_type class moe_args_dequant: - def __init__(self, num_tokens, num_experts, hidden_size, intermediate_size, - top_k, padding, hidden_states, expert_logits, gemm1_weights, - gemm2_weights, permute_info, use_routing_scales_on_input): + def __init__(self, + num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states, + expert_logits, + gemm1_weights, + gemm2_weights, + permute_info, + use_routing_scales_on_input, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_bias=None, + act_type=ActType.SwiGlu): self.num_tokens = num_tokens self.num_experts = num_experts self.hidden_size = hidden_size @@ -76,6 +124,12 @@ class moe_args_dequant: self.gemm2_weights = gemm2_weights self.permute_info = permute_info self.use_routing_scales_on_input = use_routing_scales_on_input + self.gemm1_bias = gemm1_bias + self.gemm1_alpha = gemm1_alpha + self.gemm1_beta = gemm1_beta + self.gemm1_clamp_limit = gemm1_clamp_limit + self.gemm2_bias = gemm2_bias + self.act_type = act_type def routing_reference(expertLogits, topK, padding): @@ -276,7 +330,10 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): return output -def run_moe_dequant(args, quant_mode=["fp4", "dsFp8", "perTensorFp8"]): +def run_moe_dequant(args, + quant_mode=[ + "fp4", "dsFp8", "perTensorFp8", "mxe4m3", "bf16" + ]): # Permute total_num_padded_tokens = args.permute_info["permutedBufferSize"] expanded_idx_to_permuted_idx = args.permute_info[ @@ -332,9 +389,26 @@ def run_moe_dequant(args, quant_mode=["fp4", "dsFp8", "perTensorFp8"]): if my_num_tokens == 0: continue my_a = gemm1_output[i:i + my_num_tokens] + if args.gemm1_bias is not None: + my_a += args.gemm1_bias[expert_idx] my_x1 = my_a[:, :args.intermediate_size] my_x2 = my_a[:, args.intermediate_size:] - activation_output[i:i + my_num_tokens] = F.silu(my_x2) * my_x1 + if args.act_type == ActType.SwiGlu: + alpha = args.gemm1_alpha[ + expert_idx] if args.gemm1_alpha is not None else 1.0 + beta = args.gemm1_beta[ + expert_idx] if args.gemm1_beta is not None else 0.0 + + clamp_limit = float('inf') + if args.gemm1_clamp_limit is not None: + clamp_limit = args.gemm1_clamp_limit[expert_idx] + # Clamp my_x2 (x_glu) to max=clamp_limit + my_x2 = my_x2.clamp(max=clamp_limit) + # Clamp my_x1 (x_linear) to min=-clamp_limit, max=clamp_limit + my_x1 = my_x1.clamp(min=-clamp_limit, max=clamp_limit) + + act = my_x2 * F.sigmoid(my_x2 * alpha) + activation_output[i:i + my_num_tokens] = act * (beta + my_x1) i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding @@ -348,6 +422,14 @@ def run_moe_dequant(args, quant_mode=["fp4", "dsFp8", "perTensorFp8"]): activation_output.to(torch.bfloat16)) activation_output = activation_output.to(torch.float) args.c_global_sf = c_global_sf + elif quant_mode == "mxe4m3": + activation_output = quant_dequant_mxe4m3( + activation_output.to(torch.float32)) + activation_output = activation_output.to(torch.float) + elif quant_mode == "bf16": + activation_output = activation_output.to(torch.bfloat16).to(torch.float) + elif quant_mode != "dsFp8": + raise ValueError("Invalid quant mode") # Gemm2 gemm2_output = torch.full((total_num_padded_tokens, args.hidden_size), @@ -361,6 +443,8 @@ def run_moe_dequant(args, quant_mode=["fp4", "dsFp8", "perTensorFp8"]): my_a = activation_output[i:i + my_num_tokens] my_b = args.gemm2_weights[expert_idx] my_c = my_a @ my_b.t() + if args.gemm2_bias is not None: + my_c += args.gemm2_bias[expert_idx] gemm2_output[i:i + my_num_tokens] = my_c i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding @@ -390,7 +474,8 @@ def e2m1_and_ufp8_scale_to_float_tensor_v2(e2m1_tensor: torch.Tensor, is_sf_swizzled_layout: bool = True): float_tensor = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2( e2m1_tensor.cpu(), - ufp8_scale_tensor.cpu().reshape(-1), global_scale_tensor.cpu(), + ufp8_scale_tensor.cpu().reshape(-1), + global_scale_tensor.cpu() if global_scale_tensor is not None else None, sf_vec_size, ufp8_type, is_sf_swizzled_layout) return float_tensor @@ -417,6 +502,37 @@ def e2m1_and_ufp8_scale_batches(mat_fp4: torch.Tensor, return result +def dequantize_mxe4m3(e4m3_tensor: torch.Tensor, + ue8m0_scale_tensor: torch.Tensor, + is_sf_swizzled_layout: bool = True): + float_tensor = torch.ops.tensorrt_llm.dequantize_mxe4m3_host( + e4m3_tensor.view(torch.uint8).cpu(), + ue8m0_scale_tensor.view(torch.uint8).cpu(), is_sf_swizzled_layout) + return float_tensor + + +def mxe2m1_and_ue8m0_scale_batches(mat_fp4: torch.Tensor, + scale_tensor: torch.Tensor, + is_sf_swizzled_layout: bool = True): + num_batches = mat_fp4.size(0) + ufp8_type = 0 # UE8M0 + sf_vec_size = 32 + + scale_tensor = scale_tensor.view(num_batches, -1) + + tensors = [ + e2m1_and_ufp8_scale_to_float_tensor_v2(mat_fp4[b, :, :], + scale_tensor[b, :], None, + sf_vec_size, ufp8_type, + is_sf_swizzled_layout) + for b in range(num_batches) + ] + + result = torch.stack(tensors) + + return result + + def run_moe_reference_fp4(args): sf_vec_size = 16 @@ -489,6 +605,66 @@ def run_moe_reference_per_tensor_scale_fp8(args): return run_moe_dequant(args_dequant, "perTensorFp8"), args_dequant +def run_moe_reference_mxe4m3_mxe2m1(args): + hidden_states_dequant = dequantize_mxe4m3(args.hidden_states, + args.hidden_states_scale).cuda() + + gemm1_weights_dequant = mxe2m1_and_ue8m0_scale_batches( + args.gemm1_weights, args.gemm1_scales).cuda() + + gemm2_weights_dequant = mxe2m1_and_ue8m0_scale_batches( + args.gemm2_weights, args.gemm2_scales).cuda() + + args_dequant = moe_args_dequant( + args.num_tokens, args.num_experts, args.hidden_size, + args.intermediate_size, args.top_k, args.padding, hidden_states_dequant, + args.expert_logits, gemm1_weights_dequant, gemm2_weights_dequant, + args.permute_info, args.use_routing_scales_on_input, args.gemm1_bias, + args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, + args.gemm2_bias, args.act_type) + + return run_moe_dequant(args_dequant, "mxe4m3"), args_dequant + + +def run_moe_reference_e4m3_mxe2m1(args): + hidden_states_dequant = args.hidden_states.to( + torch.float) / args.hidden_states_scale_global + + gemm1_weights_dequant = mxe2m1_and_ue8m0_scale_batches( + args.gemm1_weights, args.gemm1_scales).cuda() + + gemm2_weights_dequant = mxe2m1_and_ue8m0_scale_batches( + args.gemm2_weights, args.gemm2_scales).cuda() + + args_dequant = moe_args_dequant( + args.num_tokens, args.num_experts, args.hidden_size, + args.intermediate_size, args.top_k, args.padding, hidden_states_dequant, + args.expert_logits, gemm1_weights_dequant, gemm2_weights_dequant, + args.permute_info, args.use_routing_scales_on_input, args.gemm1_bias, + args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, + args.gemm2_bias, args.act_type) + + return run_moe_dequant(args_dequant, "perTensorFp8"), args_dequant + + +def run_moe_reference_bf16_mxe2m1(args): + gemm1_weights_dequant = mxe2m1_and_ue8m0_scale_batches( + args.gemm1_weights, args.gemm1_scales).cuda() + + gemm2_weights_dequant = mxe2m1_and_ue8m0_scale_batches( + args.gemm2_weights, args.gemm2_scales).cuda() + + args_dequant = moe_args_dequant( + args.num_tokens, args.num_experts, args.hidden_size, + args.intermediate_size, args.top_k, args.padding, args.hidden_states, + args.expert_logits, gemm1_weights_dequant, gemm2_weights_dequant, + args.permute_info, args.use_routing_scales_on_input, args.gemm1_bias, + args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, + args.gemm2_bias, args.act_type) + + return run_moe_dequant(args_dequant, "bf16"), args_dequant + + def quant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True): a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() sf_vec_size = 16 @@ -567,6 +743,70 @@ def quant_dequant_per_tensor_fp8(a): return a_pt.cuda(), a_global_sf +def quant_mxe4m3(a, is_sf_swizzled_layout=True): + a_fp4, a_sf = torch.ops.tensorrt_llm.quantize_mxe4m3_host( + a.cpu(), is_sf_swizzled_layout) + + return a_fp4, a_sf + + +def quant_mxe2m1(a, is_sf_swizzled_layout=True): + sf_vec_size = 32 + use_ue8m0 = True + + a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a.cuda(), None, sf_vec_size, + use_ue8m0, + is_sf_swizzled_layout) + + return a_fp4, a_sf + + +def quant_mxe2m1_batches(a, num_experts, is_sf_swizzled_layout=True): + quant_a = [] + sfs = [] + for i in range(num_experts): + a_fp4, a_sf = quant_mxe2m1(a[i], is_sf_swizzled_layout) + quant_a.append(a_fp4) + sfs.append(a_sf) + + result_quant_a = torch.stack(quant_a) + result_sfs = torch.stack(sfs) + + return result_quant_a, result_sfs + + +def quant_dequant_mxe4m3(a, is_sf_swizzled_layout=True): + a_fp8, a_sf = torch.ops.tensorrt_llm.quantize_mxe4m3_host( + a.cpu(), is_sf_swizzled_layout) + + a_pt = dequantize_mxe4m3(a_fp8.cpu(), a_sf.cpu()) + + return a_pt.cuda() + + +def check_accuracy(a, b, atol, rtol, percent): + if torch.any(torch.isnan(a)): + raise Exception("NaN in a") + if torch.any(torch.isnan(b)): + raise Exception("NaN in b") + assert a.shape == b.shape + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception("Mismatch percentage is %f for rtol %f" % + (mismatch_percent, rtol)) + + +def are_groups_valid(top_k_groups, n_groups): + if top_k_groups is None or n_groups is None: + return False + if top_k_groups == 0 or n_groups == 0: + return False + return True + + @pytest.mark.skipif( getSMVersion() < 100 or getSMVersion() >= 110, reason="The kernel only supports Blackwell. Current SM is %d." % @@ -624,11 +864,13 @@ class TestMoeFP8: assert top_k <= num_experts assert top_k <= 8 - assert top_k_groups <= 4 - assert num_experts > n_groups - assert num_experts % n_groups == 0 assert num_experts % 4 == 0 - assert top_k < (top_k_groups * num_experts / n_groups) + + if are_groups_valid(top_k_groups, n_groups): + assert top_k_groups <= 4 + assert num_experts > n_groups + assert num_experts % n_groups == 0 + assert top_k < (top_k_groups * num_experts / n_groups) expert_logits = torch.randn((num_tokens, num_experts), device='cuda').to(torch.float) @@ -677,23 +919,6 @@ class TestMoeFP8: # output_dequant_reference, _ = run_moe_reference_dsfp8(args) - # - # Check the results - # - def check_accuracy(a, b, atol, rtol, percent): - if torch.any(torch.isnan(a)): - raise Exception("NaN in a") - if torch.any(torch.isnan(b)): - raise Exception("NaN in b") - assert a.shape == b.shape - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() - if mismatch_percent > 1 - percent: - raise Exception("Mismatch percentage is %f for rtol %f" % - (mismatch_percent, rtol)) - check_accuracy(output_dequant_reference, output_dequant_actual, atol=0.1, @@ -755,6 +980,18 @@ class TestMoeFp4: "routing_method_type": RoutingMethodType.Renormalize }, id="RoutingRenormalize"), + pytest.param( + { + "num_experts": 128, + "top_k": 4, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize + }, + id="RoutingRenormalize_topk_4"), pytest.param( { "num_experts": 128, @@ -835,11 +1072,12 @@ class TestMoeFp4: assert top_k <= num_experts assert top_k <= 8 - if (top_k_groups is not None) and (n_groups is not None): + assert num_experts % 4 == 0 + + if are_groups_valid(top_k_groups, n_groups): assert top_k_groups <= 4 assert num_experts > n_groups assert num_experts % n_groups == 0 - assert num_experts % 4 == 0 assert top_k < (top_k_groups * num_experts / n_groups) if routing_method_type == RoutingMethodType.DeepSeekV3: @@ -1036,23 +1274,6 @@ class TestMoeFp4: output_dequant_actual = output[0].to(torch.float) - # - # Check the results - # - def check_accuracy(a, b, atol, rtol, percent): - if torch.any(torch.isnan(a)): - raise Exception("NaN in a") - if torch.any(torch.isnan(b)): - raise Exception("NaN in b") - assert a.shape == b.shape - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() - if mismatch_percent > 1 - percent: - raise Exception("Mismatch percentage is %f for rtol %f" % - (mismatch_percent, rtol)) - check_accuracy(output_dequant_reference, output_dequant_actual, atol=0.1, @@ -1065,7 +1286,7 @@ class TestMoeFp4: reason="The kernel only supports Blackwell. Current SM is %d." % getSMVersion(), ) -@pytest.mark.parametrize("num_tokens", [1, 2, 16, 64, 1024, 4096]) +@pytest.mark.parametrize("num_tokens", [1, 4, 384, 4096]) @pytest.mark.parametrize("expert_info", [(128, 0, 0, 1, True)]) @pytest.mark.parametrize("hidden_size", [2048]) @pytest.mark.parametrize("intermediate_size", [2048]) @@ -1083,11 +1304,13 @@ def test_moe_fp8_per_tensor_scale(num_tokens, expert_info, hidden_size, assert top_k <= num_experts assert top_k <= 8 - assert top_k_groups <= 4 - assert num_experts > n_groups - assert n_groups == 0 or num_experts % n_groups == 0 assert num_experts % 4 == 0 - assert n_groups == 0 or top_k < (top_k_groups * num_experts / n_groups) + + if are_groups_valid(top_k_groups, n_groups): + assert top_k_groups <= 4 + assert num_experts > n_groups + assert num_experts % n_groups == 0 + assert top_k < (top_k_groups * num_experts / n_groups) expert_logits = torch.randn((num_tokens, num_experts), device='cuda').to(torch.float) @@ -1183,25 +1406,383 @@ def test_moe_fp8_per_tensor_scale(num_tokens, expert_info, hidden_size, output_dequant_actual = output.to(torch.float) - # - # Check the results - # - def check_accuracy(a, b, atol, rtol, percent): - if torch.any(torch.isnan(a)): - raise Exception("NaN in a") - if torch.any(torch.isnan(b)): - raise Exception("NaN in b") - assert a.shape == b.shape - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() - if mismatch_percent > 1 - percent: - raise Exception("Mismatch percentage is %f for rtol %f" % - (mismatch_percent, rtol)) - check_accuracy(output_dequant_reference, output_dequant_actual, atol=0.1, rtol=0.85, percent=0.925) + + +@pytest.mark.skipif( + getSMVersion() != 100, + reason="The kernel only supports Blackwell. Current SM is %d." % + getSMVersion(), +) +@pytest.mark.parametrize("num_tokens", [1, 256, 1024]) +@pytest.mark.parametrize("hidden_size", [512]) +@pytest.mark.parametrize("intermediate_size", [512]) +@pytest.mark.parametrize( + "routing_info", + [ + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize + }, + id="RoutingRenormalize_topk_8"), + pytest.param( + { + "num_experts": 128, + "top_k": 4, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize + }, + id="RoutingRenormalize_topk_4"), + ], +) +@pytest.mark.parametrize("dtype_activation", ["mxfp8", "bf16", "fp8"]) +@pytest.mark.parametrize("act_type_str", ["SwiGlu", "GatedSilu"]) +@pytest.mark.parametrize("use_autotune", [True, False], + ids=["autotune", "no_autotune"]) +def test_moe_mxe2m1_weights(num_tokens, hidden_size, intermediate_size, + routing_info, dtype_activation, act_type_str, + use_autotune): + torch.random.manual_seed(0) + + # + # Data Generation + # + + act_type = ActType.SwiGlu + num_experts = routing_info["num_experts"] + top_k = routing_info["top_k"] + n_groups = routing_info["n_groups"] + top_k_groups = routing_info["top_k_groups"] + routed_scaling = routing_info["routed_scaling"] + has_routing_bias = routing_info["has_routing_bias"] + routing_method_type = routing_info["routing_method_type"] + # Perfect expert distribution results in `num_tokens * top_k / num_experts` tokens per expert. + tile_tokens_dim = (num_tokens * top_k) // num_experts + # And pad the number of tokens to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(tile_tokens_dim) + # At least padded to 8 tokens per CTA tile + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + padding = tile_tokens_dim + if padding >= 256: + pytest.skip("Routing kernel requires that padding be less than 256") + + if act_type_str == "GatedSilu": + if not use_autotune or dtype_activation != "mxfp8": + pytest.skip("GatedSilu is tested only with autotune and mxfp8") + if not use_autotune: + if dtype_activation != "mxfp8": + pytest.skip("No autotune is tested only with mxfp8") + if top_k != 8: + if dtype_activation != "mxfp8": + pytest.skip("TopK = 4 is tested only with mxfp8") + if num_tokens == 1024: + if top_k != 4 or dtype_activation != "mxfp8" or act_type_str != "SwiGlu" or not use_autotune: + pytest.skip( + "1024 tokens is tested only with topk=4, mxfp8, SwiGlu, and autotune" + ) + + assert top_k <= num_experts + assert top_k <= 8 + assert hidden_size % 128 == 0 + assert intermediate_size % 128 == 0 + + if are_groups_valid(top_k_groups, n_groups): + assert top_k_groups == 4 + assert num_experts > n_groups + assert num_experts % n_groups == 0 + assert top_k < (top_k_groups * num_experts / n_groups) + + if routing_method_type == RoutingMethodType.Renormalize or routing_method_type == RoutingMethodType.RenormalizeNaive: + expert_logits = torch.randn((num_tokens, num_experts), + device='cuda').to(torch.bfloat16) + else: + raise ValueError("Invalid routing method type") + + if has_routing_bias: + routing_bias = torch.randn(num_experts, + device="cuda", + dtype=torch.bfloat16) + else: + routing_bias = None + + hidden_states = 2 * torch.randn( + (num_tokens, hidden_size), device='cuda', dtype=torch.float32) + gemm1_weights = torch.randn( + (num_experts, 2 * intermediate_size, hidden_size), + device='cuda', + dtype=torch.bfloat16) + gemm2_weights = torch.randn((num_experts, hidden_size, intermediate_size), + device='cuda', + dtype=torch.bfloat16) + gemm1_bias = 50 * torch.randn( + num_experts, 2 * intermediate_size, device='cuda', dtype=torch.float) + gemm1_alpha = None + gemm1_beta = None + if act_type_str == "SwiGlu": + gemm1_alpha = torch.randn(num_experts, device='cuda', dtype=torch.float) + gemm1_beta = torch.randn(num_experts, device='cuda', dtype=torch.float) + + gemm1_clamp_limit = torch.full((num_experts, ), + 7.0, + device='cuda', + dtype=torch.float) + gemm2_bias = 50 * torch.randn( + num_experts, hidden_size, device='cuda', dtype=torch.float) + + hidden_states_mxe4m3 = None + hidden_states_scale_linear_mxe4m3 = None + hidden_states_scale_mxe4m3_bytes = None + if dtype_activation == "mxfp8": + # Quantize hidden states. Produces scales for activations in 128x4 layout for ref impl. + hidden_states_mxe4m3, hidden_states_scale_mxe4m3_bytes = quant_mxe4m3( + hidden_states, True) + # We do it twice to get the linear layout for scales for the FP4 kernels. + _, hidden_states_scale_linear_mxe4m3_bytes = quant_mxe4m3( + hidden_states, False) + + hidden_states_scale_linear_mxe4m3 = hidden_states_scale_linear_mxe4m3_bytes.view( + torch.uint8) # ue8m0 scaling factors + + sf_block_size = 32 + # Quantize the weights for FC1. Produces scales for weights in 128x4 layout for ref impl. + gemm1_weights_mxe2m1_bytes, gemm1_scales_mxe2m1_bytes = quant_mxe2m1_batches( + gemm1_weights, num_experts, True) + # We do it twice to get the linear layout for scales for the FP4 kernels. + _, gemm1_scales_linear_mxe2m1_bytes = quant_mxe2m1_batches( + gemm1_weights, num_experts, False) + + gemm1_weights_mxe2m1 = gemm1_weights_mxe2m1_bytes.view(torch.uint8).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2) # packed fp8 + gemm1_scales_linear_mxe2m1 = gemm1_scales_linear_mxe2m1_bytes.view( + torch.uint8).reshape(num_experts, 2 * intermediate_size, hidden_size // + sf_block_size) # ue8m0 scaling factors + + # Quantize the weights for FC2. Produces scales for weights in 128x4 layout for ref impl. + gemm2_weights_mxe2m1_bytes, gemm2_scales_mxe2m1_bytes = quant_mxe2m1_batches( + gemm2_weights, num_experts, True) + # We do it twice to get the linear layout for scales for the FP4 kernels. + _, gemm2_scales_linear_mxe2m1_bytes = quant_mxe2m1_batches( + gemm2_weights, num_experts, False) + + gemm2_weights_mxe2m1 = gemm2_weights_mxe2m1_bytes.view(torch.uint8).reshape( + num_experts, hidden_size, intermediate_size // 2) # packed mxe2m1 + gemm2_scales_linear_mxe2m1 = gemm2_scales_linear_mxe2m1_bytes.view( + torch.uint8).reshape(num_experts, hidden_size, intermediate_size // + sf_block_size) # ue8m0 scaling factors + if routing_method_type == RoutingMethodType.Renormalize: + permute_info, scores = routing_reference_renormalize( + expert_logits, top_k, num_experts, padding) + elif routing_method_type == RoutingMethodType.RenormalizeNaive: + permute_info, scores = routing_reference_renormalize_naive( + expert_logits, top_k, num_experts, padding) + else: + raise ValueError("Invalid routing method type") + + input_hidden_states = None + input_hidden_global_scale = None + if dtype_activation == "mxfp8": + input_hidden_states = hidden_states_mxe4m3 + elif dtype_activation == "bf16": + input_hidden_states = hidden_states.to(torch.bfloat16) + elif dtype_activation == "fp8": + input_hidden_states, input_hidden_global_scale = quant_fp8_per_tensor( + hidden_states) + + args = moe_args( + num_tokens, num_experts, hidden_size, intermediate_size, top_k, padding, + input_hidden_states, hidden_states_scale_mxe4m3_bytes + if dtype_activation == "mxfp8" else None, input_hidden_global_scale, + scores, gemm1_weights_mxe2m1_bytes, gemm1_scales_mxe2m1_bytes, None, + gemm2_weights_mxe2m1_bytes, gemm2_scales_mxe2m1_bytes, None, + permute_info, False, gemm1_bias, gemm1_alpha, gemm1_beta, + gemm1_clamp_limit, gemm2_bias, act_type) + + if dtype_activation == "mxfp8": + output_dequant_reference, args_dequant = run_moe_reference_mxe4m3_mxe2m1( + args) + elif dtype_activation == "bf16": + output_dequant_reference, args_dequant = run_moe_reference_bf16_mxe2m1( + args) + elif dtype_activation == "fp8": + output_dequant_reference, args_dequant = run_moe_reference_e4m3_mxe2m1( + args) + else: + raise ValueError("Invalid dtype_activation") + # + # Run the reference implementations + # + # It is important to run the reference implementation before the TRT-LLM kernel + # because the MoE shuffles the weights in-place. + # output_dequant_reference, args_dequant = run_moe_reference_mxe4m3_mxe2m1(args) + + # FIXME: this depends on the kernel internals + epilogue_tile_m = 128 + + # Reorder rows of W1 and scales for fused gated activation + gemm1_weights_mxe2m1_interleaved = [] + gemm1_scales_mxe2m1_interleaved = [] + gemm1_bias_interleaved = [] + for i in range(num_experts): + gemm1_weights_mxe2m1_interleaved.append( + reorder_rows_for_gated_act_gemm(gemm1_weights_mxe2m1[i].clone())) + gemm1_scales_mxe2m1_interleaved.append( + reorder_rows_for_gated_act_gemm( + gemm1_scales_linear_mxe2m1[i].clone())) + gemm1_bias_interleaved.append( + reorder_rows_for_gated_act_gemm(gemm1_bias[i].clone().reshape( + -1, 1))) + + # Stack weights and scales for all experts + gemm1_weights_mxe2m1_interleaved = torch.stack( + gemm1_weights_mxe2m1_interleaved).reshape(num_experts, + 2 * intermediate_size, + hidden_size // 2) + gemm1_scales_mxe2m1_interleaved = torch.stack( + gemm1_scales_mxe2m1_interleaved).reshape(num_experts, + 2 * intermediate_size, + hidden_size // sf_block_size) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_mxe2m1_shuffled = [] + gemm1_scales_mxe2m1_shuffled = [] + gemm2_weights_mxe2m1_shuffled = [] + gemm2_scales_mxe2m1_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + for i in range(num_experts): + gemm1_weights_mxe2m1_shuffled.append( + shuffle_matrix_a( + gemm1_weights_mxe2m1_interleaved[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_scales_mxe2m1_shuffled.append( + shuffle_matrix_sf_a( + gemm1_scales_mxe2m1_interleaved[i].view(torch.uint8), + epilogue_tile_m)) + + gemm1_bias_shuffled.append( + shuffle_matrix_a(gemm1_bias_interleaved[i], epilogue_tile_m)) + + gemm2_weights_mxe2m1_shuffled.append( + shuffle_matrix_a(gemm2_weights_mxe2m1[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_scales_mxe2m1_shuffled.append( + shuffle_matrix_sf_a(gemm2_scales_linear_mxe2m1[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_bias_shuffled.append( + shuffle_matrix_a(gemm2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m)) + + # Stack weights for all experts + gemm1_weights_mxe2m1_shuffled = torch.stack(gemm1_weights_mxe2m1_shuffled) + gemm1_scales_mxe2m1_shuffled = torch.stack( + gemm1_scales_mxe2m1_shuffled).view(torch.uint8).reshape( + num_experts, 2 * intermediate_size, hidden_size // sf_block_size) + + gemm2_weights_mxe2m1_shuffled = torch.stack(gemm2_weights_mxe2m1_shuffled) + gemm2_scales_mxe2m1_shuffled = torch.stack( + gemm2_scales_mxe2m1_shuffled).view(torch.uint8).reshape( + num_experts, hidden_size, intermediate_size // sf_block_size) + + gemm1_bias_shuffled = torch.stack(gemm1_bias_shuffled).reshape( + num_experts, -1) + gemm2_bias_shuffled = torch.stack(gemm2_bias_shuffled).reshape( + num_experts, -1) + + if dtype_activation == "fp8": + # c_global_sf: fc2_input_scale + scale_c_fc1 = ( + args_dequant.c_global_sf * + (1.0 / args.hidden_states_scale_global)).expand(num_experts).to( + torch.float).cuda().contiguous() + + # self.fc31_alpha + scale_gate_fc1 = ( + 1.0 / args.hidden_states_scale_global).expand(num_experts).to( + torch.float).cuda().contiguous() + + # self.fc2_alpha + scale_c_fc2 = (1.0 / args_dequant.c_global_sf).expand(num_experts).to( + torch.float).cuda().contiguous() + + # NOTE: correct the beta and clamp to account for the global scale factor + # Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h + # for more details + gemm1_beta = gemm1_beta * args.hidden_states_scale_global + gemm1_clamp_limit = gemm1_clamp_limit * args.hidden_states_scale_global + # Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h + # for more details + gemm1_bias_shuffled = gemm1_bias_shuffled * args.hidden_states_scale_global + gemm2_bias_shuffled = gemm2_bias_shuffled * args_dequant.c_global_sf + + # + # Run the TRT-LLM kernel + # + unpadded_hidden_size = hidden_size + with autotune(use_autotune): + if dtype_activation == "mxfp8": + # Test fused unpadding by checking only half of the output. + unpadded_hidden_size = hidden_size // 2 + output = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner( + expert_logits, routing_bias, + hidden_states_mxe4m3.cuda().view(torch.float8_e4m3fn), + hidden_states_scale_linear_mxe4m3.cuda(), + gemm1_weights_mxe2m1_shuffled.cuda(), + gemm1_scales_mxe2m1_shuffled.cuda(), gemm1_bias_shuffled.cuda(), + gemm1_alpha, gemm1_beta, gemm1_clamp_limit, + gemm2_weights_mxe2m1_shuffled.cuda(), + gemm2_scales_mxe2m1_shuffled.cuda(), gemm2_bias_shuffled.cuda(), + num_experts, top_k, n_groups, top_k_groups, intermediate_size, + unpadded_hidden_size, 0, num_experts, routed_scaling, + tile_tokens_dim, routing_method_type, act_type.value) + elif dtype_activation == "bf16": + output = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner( + expert_logits, routing_bias, + hidden_states.cuda().to(torch.bfloat16), + gemm1_weights_mxe2m1_shuffled.cuda(), + gemm1_scales_mxe2m1_shuffled.cuda(), gemm1_bias_shuffled.cuda(), + gemm1_alpha, gemm1_beta, gemm1_clamp_limit, + gemm2_weights_mxe2m1_shuffled.cuda(), + gemm2_scales_mxe2m1_shuffled.cuda(), gemm2_bias_shuffled.cuda(), + num_experts, top_k, n_groups, top_k_groups, intermediate_size, + 0, num_experts, routed_scaling, tile_tokens_dim, + routing_method_type, act_type.value) + elif dtype_activation == "fp8": + output = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner( + expert_logits, routing_bias, + input_hidden_states.cuda().to(torch.float8_e4m3fn), + gemm1_weights_mxe2m1_shuffled.cuda(), + gemm1_scales_mxe2m1_shuffled.cuda(), gemm1_bias_shuffled.cuda(), + gemm1_alpha, gemm1_beta, gemm1_clamp_limit, + gemm2_weights_mxe2m1_shuffled.cuda(), + gemm2_scales_mxe2m1_shuffled.cuda(), gemm2_bias_shuffled.cuda(), + scale_c_fc1, scale_gate_fc1, scale_c_fc2, num_experts, top_k, + n_groups, top_k_groups, intermediate_size, 0, num_experts, + routed_scaling, tile_tokens_dim, routing_method_type, + act_type.value) + else: + raise ValueError("Invalid dtype_activation") + + output_dequant_actual = output.to(torch.float) + output_dequant_reference = output_dequant_reference[:, : + unpadded_hidden_size].contiguous( + ) + percent = 0.8 if dtype_activation == "mxfp8" else 0.85 + check_accuracy(output_dequant_reference, + output_dequant_actual, + atol=0.0, + rtol=0.10, + percent=percent) diff --git a/tests/unittest/_torch/thop/test_scaled_mm.py b/tests/unittest/_torch/thop/test_scaled_mm.py index f3cb8ea856..40dfa4e0cb 100644 --- a/tests/unittest/_torch/thop/test_scaled_mm.py +++ b/tests/unittest/_torch/thop/test_scaled_mm.py @@ -69,7 +69,10 @@ def test_fp8_scaled_mm(output_dtype, m, k_n): use_fast_accum=True, ) os.environ["CUBLASLT_WORKSPACE_SIZE"] = old_env - np.testing.assert_allclose(ref.float().cpu(), output.float().cpu()) + np.testing.assert_allclose(ref.float().cpu(), + output.float().cpu(), + atol=0.01, + rtol=0.01) if getSMVersion() == 90: cutlass_output = torch.ops.trtllm.cutlass_scaled_mm( @@ -83,7 +86,9 @@ def test_fp8_scaled_mm(output_dtype, m, k_n): # TODO(zhenhuan): cutlass kernel has acc issue on some shapes try: np.testing.assert_allclose(ref.float().cpu(), - cutlass_output.float().cpu()) + cutlass_output.float().cpu(), + atol=1, + rtol=0.01) except Exception as e: warn(RuntimeWarning("cutlass result is not correct: " + repr(e))) diff --git a/tests/unittest/_torch/thop/test_w4a8_mxfp4_mxfp8_gemm.py b/tests/unittest/_torch/thop/test_w4a8_mxfp4_mxfp8_gemm.py index 09fb1a8542..aafccd9774 100644 --- a/tests/unittest/_torch/thop/test_w4a8_mxfp4_mxfp8_gemm.py +++ b/tests/unittest/_torch/thop/test_w4a8_mxfp4_mxfp8_gemm.py @@ -70,8 +70,8 @@ class TestFunctional(unittest.TestCase): mat_b[0][0] = 36 mat_b_ref[0][0] = 2.0 - a_block_sf = torch.ops.trtllm.nvfp4_block_scale_interleave(a_block_sf) - b_block_sf = torch.ops.trtllm.nvfp4_block_scale_interleave(b_block_sf) + a_block_sf = torch.ops.trtllm.block_scale_interleave(a_block_sf) + b_block_sf = torch.ops.trtllm.block_scale_interleave(b_block_sf) c = (torch.ops.trtllm.w4a8_mxfp4_fp8_gemm(mat_a, mat_b, a_block_sf, b_block_sf, a_sf, diff --git a/tests/unittest/_torch/thop/test_weight_only_quant_gemm.py b/tests/unittest/_torch/thop/test_weight_only_quant_gemm.py index fab60be84b..447c807a8c 100644 --- a/tests/unittest/_torch/thop/test_weight_only_quant_gemm.py +++ b/tests/unittest/_torch/thop/test_weight_only_quant_gemm.py @@ -15,7 +15,7 @@ import pytest import torch -from _torch.helpers import calc_diff +from _torch.helpers import calc_diff, calc_woq_tolerence def weight_only_quant_gemm_reference(a, b, b_scales): @@ -29,18 +29,6 @@ def weight_only_quant_gemm_reference(a, b, b_scales): return ref.to(dtype=a_dtype) -def woq_tolerence_calculate(output, output_ref, b_dtype): - if b_dtype == torch.int8: - bits_in_type = 8 - elif b_dtype == torch.quint4x2: - bits_in_type = 4 - quant_range_scale = 1.0 / float(1 << (bits_in_type - 1)) - max_val = torch.max(abs(output_ref)).item() - atol = (max_val * quant_range_scale) * 1.5 # allow for rounding - - return atol - - @pytest.mark.parametrize( "k, n", [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (1024, 1024)], @@ -79,5 +67,5 @@ def test_weight_only_quant_gemm(a_dtype, b_dtype, m, k, n): # check accuracy diff = calc_diff(output, output_ref) assert diff < 1e-3, f"Difference {diff} >= 1e-3" - atol = woq_tolerence_calculate(output, output_ref, b_dtype) + atol = calc_woq_tolerence(output_ref, b_dtype) torch.testing.assert_close(output_ref, output, atol=atol, rtol=1e-7) diff --git a/tests/unittest/api_stability/api_stability_core.py b/tests/unittest/api_stability/api_stability_core.py index 2278fad201..61650d5909 100644 --- a/tests/unittest/api_stability/api_stability_core.py +++ b/tests/unittest/api_stability/api_stability_core.py @@ -27,6 +27,7 @@ from tensorrt_llm.executor.result import TokenLogprobs from tensorrt_llm.llmapi import (CalibConfig, CompletionOutput, GuidedDecodingParams, QuantConfig, RequestOutput, SamplingParams) +from tensorrt_llm.llmapi.llm_args import SamplerType from tensorrt_llm.llmapi.llm_utils import LlmArgs from tensorrt_llm.logger import Singleton diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 984c8953ec..86f740c384 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -27,6 +27,10 @@ methods: annotation: Optional[int] default: null status: prototype + return_perf_metrics: + annotation: bool + default: False + status: prototype # Bindings and mirrored configs peft_cache_config: annotation: Optional[tensorrt_llm.llmapi.llm_args.PeftCacheConfig] @@ -107,10 +111,10 @@ methods: annotation: bool default: False status: beta - enable_trtllm_sampler: - annotation: bool - default: False - status: prototype + sampler_type: + annotation: Union[str, tensorrt_llm.llmapi.llm_args.SamplerType] + default: auto + status: beta enable_iter_perf_stats: annotation: bool default: False @@ -144,7 +148,7 @@ methods: default: False status: prototype allreduce_strategy: - annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL']] + annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL', 'NCCL_SYMMETRIC']] default: AUTO status: beta decoding_config: diff --git a/tests/unittest/api_stability/references/quant_config.yaml b/tests/unittest/api_stability/references/quant_config.yaml index dbf1201e49..510b6a0186 100644 --- a/tests/unittest/api_stability/references/quant_config.yaml +++ b/tests/unittest/api_stability/references/quant_config.yaml @@ -16,6 +16,9 @@ methods: kv_cache_quant_algo: annotation: Optional[tensorrt_llm.quantization.mode.QuantAlgo] default: null + mamba_ssm_cache_dtype: + annotation: Optional[str] + default: null pre_quant_scale: annotation: bool default: false diff --git a/tests/unittest/api_stability/references/request_output.yaml b/tests/unittest/api_stability/references/request_output.yaml index 52e499dd14..7e3054cd5e 100644 --- a/tests/unittest/api_stability/references/request_output.yaml +++ b/tests/unittest/api_stability/references/request_output.yaml @@ -11,4 +11,13 @@ methods: clear_logprob_params: parameters: {} return_annotation: None + record_stats: + parameters: + output: + annotation: tensorrt_llm.executor.result.CompletionOutput + default: inspect._empty + stats: + annotation: Optional[dict[str, float]] + default: None + return_annotation: None properties: {} diff --git a/tests/unittest/bindings/test_bindings_ut.py b/tests/unittest/bindings/test_bindings_ut.py index e12fd52cb4..f049a4437c 100644 --- a/tests/unittest/bindings/test_bindings_ut.py +++ b/tests/unittest/bindings/test_bindings_ut.py @@ -26,7 +26,8 @@ def test_quant_mode(): quant_mode = _tb.QuantMode.from_description(True, True, True, True, True, True, True, True, False, False, - False, False, False, False) + False, False, False, False, + False, False) assert quant_mode.has_int4_weights quant_mode -= _tb.QuantMode.int4_weights() assert not quant_mode.has_int4_weights diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py index 6dcaa0d953..8556cf54d6 100644 --- a/tests/unittest/bindings/test_executor_bindings.py +++ b/tests/unittest/bindings/test_executor_bindings.py @@ -1314,6 +1314,7 @@ def test_kv_cache_config(): assert config.enable_partial_reuse == True assert config.copy_on_partial_reuse == True assert config.use_uvm == False + assert config.attention_dp_events_gather_period_ms == 5 config.enable_block_reuse = False config.max_tokens = 1 @@ -1328,6 +1329,7 @@ def test_kv_cache_config(): config.enable_partial_reuse = False config.copy_on_partial_reuse = False config.use_uvm = True + config.attention_dp_events_gather_period_ms = 10 assert config.enable_block_reuse == False assert config.max_tokens == 1 assert config.max_attention_window == [2] @@ -1341,6 +1343,7 @@ def test_kv_cache_config(): assert config.enable_partial_reuse == False assert config.copy_on_partial_reuse == False assert config.use_uvm == True + assert config.attention_dp_events_gather_period_ms == 10 kwargs = { "enable_block_reuse": True, @@ -1354,7 +1357,8 @@ def test_kv_cache_config(): "event_buffer_max_size": 2048, "enable_partial_reuse": True, "copy_on_partial_reuse": False, - "use_uvm": True + "use_uvm": True, + "attention_dp_events_gather_period_ms": 10 } config = trtllm.KvCacheConfig(**kwargs) for k, v in kwargs.items(): diff --git a/tests/unittest/llmapi/apps/_test_openai_chat.py b/tests/unittest/llmapi/apps/_test_openai_chat.py index 982d06fc9a..6e58b09478 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat.py @@ -533,10 +533,10 @@ server_with_custom_sampler = make_server_with_custom_sampler_fixture('chat') 'server_with_custom_sampler', [ { - 'use_trtllm_sampler': False + 'sampler_type': "TorchSampler" }, # torch_sampler { - 'use_trtllm_sampler': True + 'sampler_type': "TRTLLMSampler" }, # trtllm_sampler ], indirect=True, diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py index b31e7f2b05..d92ca06167 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py @@ -31,7 +31,7 @@ def temp_extra_llm_api_options_file(request): }, "build_config": { "max_num_tokens": 16384, - } + }, } with open(temp_file_path, 'w') as f: @@ -46,7 +46,10 @@ def temp_extra_llm_api_options_file(request): @pytest.fixture(scope="module") def server(model_name: str, temp_extra_llm_api_options_file: str): model_path = get_model_path(model_name) - args = ["--extra_llm_api_options", temp_extra_llm_api_options_file] + args = [ + "--extra_llm_api_options", temp_extra_llm_api_options_file, + "--max_batch_size", "64" + ] with RemoteOpenAIServer(model_path, args) as remote_server: yield remote_server diff --git a/tests/unittest/llmapi/apps/_test_openai_completions.py b/tests/unittest/llmapi/apps/_test_openai_completions.py index ce48879535..3e1c96cff3 100644 --- a/tests/unittest/llmapi/apps/_test_openai_completions.py +++ b/tests/unittest/llmapi/apps/_test_openai_completions.py @@ -80,6 +80,18 @@ def test_single_completion(client: openai.OpenAI, model_name): assert len(completion.choices[0].text) >= 1 +def test_single_completion_with_too_long_prompt(client: openai.OpenAI, + model_name): + completion = client.completions.create( + model=model_name, + prompt="Hello, my name is" * 100, + max_tokens=5, + temperature=0.0, + ) + + print(completion) + + @pytest.mark.asyncio(loop_scope="module") @pytest.mark.parametrize("echo", [True, False]) async def test_completion_streaming(async_client: openai.AsyncOpenAI, @@ -383,10 +395,10 @@ server_with_custom_sampler = make_server_with_custom_sampler_fixture( 'server_with_custom_sampler', [ { - 'use_trtllm_sampler': False + 'sampler_type': "TorchSampler" }, # torch_sampler { - 'use_trtllm_sampler': True + 'sampler_type': "TRTLLMSampler" }, # trtllm_sampler ], indirect=True, diff --git a/tests/unittest/llmapi/apps/_test_openai_lora.py b/tests/unittest/llmapi/apps/_test_openai_lora.py index 313304a251..8e62412242 100644 --- a/tests/unittest/llmapi/apps/_test_openai_lora.py +++ b/tests/unittest/llmapi/apps/_test_openai_lora.py @@ -39,7 +39,10 @@ def temp_extra_llm_api_options_file(): "max_lora_rank": 8, "max_loras": 4, "max_cpu_loras": 4, - } + }, + # Disable CUDA graph + # TODO: remove this once we have a proper fix for CUDA graph in LoRA + "cuda_graph_config": None } with open(temp_file_path, 'w') as f: diff --git a/tests/unittest/llmapi/apps/_test_openai_misc.py b/tests/unittest/llmapi/apps/_test_openai_misc.py index 51e3d4f840..8cc715389f 100644 --- a/tests/unittest/llmapi/apps/_test_openai_misc.py +++ b/tests/unittest/llmapi/apps/_test_openai_misc.py @@ -10,22 +10,30 @@ from ..test_llm import get_model_path from .openai_server import RemoteOpenAIServer -@pytest.fixture(scope="module") -def model_name(): - return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" - - @pytest.fixture(scope="module", params=["trt", "pytorch"]) def backend(request): return request.param +@pytest.fixture(scope="module") +def model_name(backend): + # Note: TRT backend does not support Qwen3-0.6B-Base, + # and PyTorch backend does not support going over the limit of "max_position_embeddings" tokens + # of TinyLlama. + if backend == "trt": + return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + else: + return "Qwen3/Qwen3-0.6B-Base" + + @pytest.fixture(scope="module", params=["8"]) def max_batch_size(request): return request.param -@pytest.fixture(scope="module", params=["80000"]) +# Note: In the model Qwen3-0.6B-Base, "max_position_embeddings" is 32768, +# so the inferred max_seq_len is 32768. +@pytest.fixture(scope="module", params=["32768"]) def max_seq_len(request): return request.param diff --git a/tests/unittest/llmapi/apps/_test_openai_prometheus.py b/tests/unittest/llmapi/apps/_test_openai_prometheus.py new file mode 100644 index 0000000000..8a360668fd --- /dev/null +++ b/tests/unittest/llmapi/apps/_test_openai_prometheus.py @@ -0,0 +1,67 @@ +import logging +import os +import tempfile +from urllib.request import urlopen + +import pytest +import yaml + +from ..test_llm import get_model_path +from .openai_server import RemoteOpenAIServer + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"]) +def model_name(): + return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +@pytest.fixture(scope="module") +def temp_extra_llm_api_options_file(request): + temp_dir = tempfile.gettempdir() + temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml") + try: + extra_llm_api_options_dict = {"return_perf_metrics": True} + + with open(temp_file_path, 'w') as f: + yaml.dump(extra_llm_api_options_dict, f) + + yield temp_file_path + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + +@pytest.fixture(scope="module") +def server(model_name: str, + temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer: + model_path = get_model_path(model_name) + args = ["--backend", "pytorch", "--tp_size", "1"] + args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file]) + logger.info(f"Starting server, model: {model_name}, args: {args}") + with RemoteOpenAIServer(model_path, args) as remote_server: + yield remote_server + logger.info("Tests completed, shutting down server") + + +def test_metrics_endpoint(server: RemoteOpenAIServer): + + client = server.get_client() + client.completions.create( + model="Server", + prompt="Hello, my name is", + max_tokens=25, + stream=False, + ) + + response = urlopen(f'{server.url_root}/prometheus/metrics') + assert response.status is 200 + + data = response.read().decode("utf-8") + assert "request_success_total" in data + assert "e2e_request_latency_seconds" in data + assert "time_to_first_token_seconds" in data + assert "request_queue_time_seconds" in data diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_example.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_example.py index 262eafa820..6921c024d5 100644 --- a/tests/unittest/llmapi/apps/_test_trtllm_serve_example.py +++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_example.py @@ -1,8 +1,11 @@ +import json import os import subprocess import sys +import tempfile import pytest +import yaml from .openai_server import RemoteOpenAIServer @@ -16,10 +19,26 @@ def model_name(): @pytest.fixture(scope="module") -def server(model_name: str): +def temp_extra_llm_api_options_file(): + temp_dir = tempfile.gettempdir() + temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml") + try: + extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"} + with open(temp_file_path, 'w') as f: + yaml.dump(extra_llm_api_options_dict, f) + + yield temp_file_path + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + +@pytest.fixture(scope="module") +def server(model_name: str, temp_extra_llm_api_options_file: str): model_path = get_model_path(model_name) # fix port to facilitate concise trtllm-serve examples - with RemoteOpenAIServer(model_path, port=8000) as remote_server: + args = ["--extra_llm_api_options", temp_extra_llm_api_options_file] + with RemoteOpenAIServer(model_path, args, port=8000) as remote_server: yield remote_server @@ -40,8 +59,19 @@ def test_trtllm_serve_examples(exe: str, script: str, server: RemoteOpenAIServer, example_root: str): client_script = os.path.join(example_root, script) # CalledProcessError will be raised if any errors occur - subprocess.run([exe, client_script], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - check=True) + result = subprocess.run([exe, client_script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True) + if script.startswith("curl"): + # For curl scripts, we expect a JSON response + result_stdout = result.stdout.strip() + try: + data = json.loads(result_stdout) + assert "code" not in data or data[ + "code"] == 200, f"Unexpected response: {data}" + except json.JSONDecodeError as e: + pytest.fail( + f"Failed to parse JSON response from {script}: {e}\nStdout: {result_stdout}\nStderr: {result.stderr}" + ) diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_benchmark.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_benchmark.py new file mode 100644 index 0000000000..502c366561 --- /dev/null +++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_benchmark.py @@ -0,0 +1,124 @@ +import os +import subprocess +import sys +import tempfile + +import pytest +import yaml +from utils.util import skip_gpu_memory_less_than_80gb + +from .openai_server import RemoteOpenAIServer + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from test_llm import get_model_path + + +@pytest.fixture(scope="module", ids=["Qwen2.5-VL-3B-Instruct"]) +def model_name(): + return "Qwen2.5-VL-3B-Instruct" + + +@pytest.fixture(scope="module") +def model_path(model_name: str): + return get_model_path(model_name) + + +@pytest.fixture(scope="module") +def temp_extra_llm_api_options_file(request): + temp_dir = tempfile.gettempdir() + temp_file_path = os.path.join( + temp_dir, "extra_llm_api_options_multimodal_benchmark.yaml") + try: + extra_llm_api_options_dict = { + "kv_cache_config": { + "free_gpu_memory_fraction": 0.6, + }, + "max_num_tokens": 16384, # for pytorch backend + # NOTE: This is for video support. + "build_config": { + "max_num_tokens": 16384, + } + } + + with open(temp_file_path, 'w') as f: + yaml.dump(extra_llm_api_options_dict, f) + + yield temp_file_path + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + +@pytest.fixture(scope="module") +def server(model_path: str, temp_extra_llm_api_options_file: str): + # Use pytorch backend for multimodal support and fix port to facilitate benchmarking + args = ["--extra_llm_api_options", temp_extra_llm_api_options_file] + with RemoteOpenAIServer(model_path, port=8000, + cli_args=args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def benchmark_root(): + llm_root = os.getenv("LLM_ROOT") + return os.path.join(llm_root, "tensorrt_llm", "serve", "scripts") + + +# TODO: Add this to /code/llm-models +def vision_arena_dataset_path(): + """Return a vision arena dataset path for testing.""" + return "lmarena-ai/vision-arena-bench-v0.1" + + +@skip_gpu_memory_less_than_80gb +@pytest.mark.parametrize("dataset_name,dataset_args", [("random_image", { + "--num-images": "1", + "--image-size": "512", +}), ("random_image", { + "--num-images": "2", + "--image-size": "512", +}), ("hf", { + "--dataset-path": vision_arena_dataset_path(), +})], + ids=[ + "random_image-single_image", + "random_image-dual_images", + "hf-vision_arena_dataset" + ]) +def test_trtllm_serve_multimodal_benchmark(server: RemoteOpenAIServer, + benchmark_root: str, model_path: str, + dataset_name: str, + dataset_args: dict): + """Test multimodal benchmark serving with different datasets.""" + client_script = os.path.join(benchmark_root, "benchmark_serving.py") + + # Base command arguments + benchmark_cmd = [ + "python3", + client_script, + "--backend", + "openai-chat", # Required for multimodal + "--dataset-name", + dataset_name, + "--model", + "qwen2.5-vl", + "--tokenizer", + model_path, + "--num-prompts", + "10", # Small number for testing + ] + + # Add dataset-specific arguments + for key, value in dataset_args.items(): + benchmark_cmd.extend([key, str(value)]) + + # CalledProcessError will be raised if any errors occur + result = subprocess.run(benchmark_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True) + + # Basic validation that the benchmark ran successfully + assert result.returncode == 0 + assert "Serving Benchmark Result" in result.stdout diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_example.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_example.py index f86f969427..5b28e12675 100644 --- a/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_example.py +++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_example.py @@ -31,7 +31,7 @@ def temp_extra_llm_api_options_file(request): # NOTE: This is for video support. "build_config": { "max_num_tokens": 16384, - } + }, } with open(temp_file_path, 'w') as f: @@ -46,10 +46,7 @@ def temp_extra_llm_api_options_file(request): @pytest.fixture(scope="module") def server(model_name: str, temp_extra_llm_api_options_file: str): model_path = get_model_path(model_name) - args = [ - "--backend", "pytorch", "--extra_llm_api_options", - temp_extra_llm_api_options_file - ] + args = ["--extra_llm_api_options", temp_extra_llm_api_options_file] with RemoteOpenAIServer(model_path, port=8000, cli_args=args) as remote_server: yield remote_server diff --git a/tests/unittest/llmapi/apps/utils.py b/tests/unittest/llmapi/apps/utils.py index bcfd316714..073760d51f 100644 --- a/tests/unittest/llmapi/apps/utils.py +++ b/tests/unittest/llmapi/apps/utils.py @@ -151,8 +151,7 @@ def make_server_with_custom_sampler_fixture(api_type: str) -> Callable: def server_with_custom_sampler(model_name: str, request: Any, backend: str, tmp_path: Path) -> RemoteOpenAIServer: '''Fixture to launch a server (pytorch backend only) with a custom sampler configuration.''' - use_trtllm_sampler = getattr(request, 'param', - {}).get('use_trtllm_sampler', True) + sampler_type = getattr(request, 'param', {}).get('sampler_type', "auto") if backend != 'pytorch': pytest.skip( f"Server with custom sampler is only supported for pytorch backend, skipping for {backend}" @@ -162,7 +161,7 @@ def make_server_with_custom_sampler_fixture(api_type: str) -> Callable: temp_file_path = tmp_path / f'test_sampler_config_{request.node.name}.yaml' extra_llm_api_options_dict = { 'enable_chunked_prefill': True, - 'enable_trtllm_sampler': use_trtllm_sampler + 'sampler_type': sampler_type } with temp_file_path.open('w') as f: yaml.dump(extra_llm_api_options_dict, f) diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index 2b7c606bf4..4f7488205a 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -42,7 +42,7 @@ from tensorrt_llm.llmapi.llm_utils import (BuildConfig, QuantAlgo, QuantConfig, from tensorrt_llm.llmapi.tokenizer import (TokenizerBase, TransformersTokenizer, load_hf_tokenizer) from tensorrt_llm.llmapi.utils import get_total_gpu_memory -from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.models.automodel import AutoConfig, AutoModelForCausalLM from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode from tensorrt_llm.sampling_params import (BatchedLogitsProcessor, @@ -1459,20 +1459,7 @@ def llama_v2_13b_lora_from_dir_test_harness(**llm_kwargs): assert similar(output.outputs[0].text, ref) -@pytest.mark.parametrize( - "lora_adapter_count_per_call, max_loras, max_cpu_loras, repeat_calls, repeats_per_call", - [ - # Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single - # llm.generate call, that's repeated twice. - ([ - 2, - ], 1, 2, 2, 3), - # Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU - # cache size < LoRA CPU cache size - ([2, 2, 2], 1, 3, 1, 1), - ]) -@skip_gpu_memory_less_than_40gb -def test_llama_7b_multi_lora_evict_load_new_adapters( +def _check_llama_7b_multi_lora_evict_load_new_adapters( lora_adapter_count_per_call: list[int], max_loras: int, max_cpu_loras: int, repeat_calls: int, repeats_per_call: int): # For LoRA checkpoints without finetuned embedding and lm_head, we can either: @@ -1493,6 +1480,43 @@ def test_llama_7b_multi_lora_evict_load_new_adapters( fast_build=True) +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache(): + """Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single + llm.generate call, that's repeated twice. + """ # noqa: D205 + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[2], + max_loras=1, + max_cpu_loras=2, + repeat_calls=2, + repeats_per_call=3) + + +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache(): + """Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU + cache size < LoRA CPU cache size. + """ # noqa: D205 + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[2, 2, 2], + max_loras=1, + max_cpu_loras=3, + repeat_calls=1, + repeats_per_call=1) + + +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_read_from_cache_after_insert(): + """Test that loading and then using the same adapters loaded in cache works.""" + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[3], + max_loras=3, + max_cpu_loras=3, + repeat_calls=2, + repeats_per_call=1) + + def test_llama_7b_peft_cache_config_affects_peft_cache_size(): """Tests that LLM arg of peft_cache_config affects the peft cache sizes. diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index acb831837c..440b368447 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -162,7 +162,8 @@ def test_KvCacheConfig_declaration(): secondary_offload_min_priority=1, event_buffer_max_size=0, enable_partial_reuse=True, - copy_on_partial_reuse=True) + copy_on_partial_reuse=True, + attention_dp_events_gather_period_ms=10) pybind_config = config._to_pybind() assert pybind_config.enable_block_reuse == True @@ -177,6 +178,7 @@ def test_KvCacheConfig_declaration(): assert pybind_config.event_buffer_max_size == 0 assert pybind_config.enable_partial_reuse == True assert pybind_config.copy_on_partial_reuse == True + assert pybind_config.attention_dp_events_gather_period_ms == 10 def test_KvCacheConfig_default_values(): @@ -664,15 +666,15 @@ class TestStrictBaseModelArbitraryArgs: def test_cache_transceiver_config_arbitrary_args(self): """Test that CacheTransceiverConfig rejects arbitrary arguments.""" # Valid arguments should work - config = CacheTransceiverConfig(backend="ucx", + config = CacheTransceiverConfig(backend="UCX", max_tokens_in_buffer=1024) - assert config.backend == "ucx" + assert config.backend == "UCX" assert config.max_tokens_in_buffer == 1024 # Arbitrary arguments should be rejected with pytest.raises( pydantic_core._pydantic_core.ValidationError) as exc_info: - CacheTransceiverConfig(backend="ucx", invalid_config="should_fail") + CacheTransceiverConfig(backend="UCX", invalid_config="should_fail") assert "invalid_config" in str(exc_info.value) def test_torch_compile_config_arbitrary_args(self): diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py index 8f7fb75c7f..f505bd0383 100644 --- a/tests/unittest/llmapi/test_llm_kv_cache_events.py +++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py @@ -1,6 +1,9 @@ import asyncio import time +import pytest +from utils.util import skip_single_gpu + import tensorrt_llm from tensorrt_llm import LLM from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest @@ -9,6 +12,7 @@ from tensorrt_llm._utils import KVCacheEventSerializer from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.sampling_params import SamplingParams +from tensorrt_llm.scheduling_params import SchedulingParams from .test_llm import get_model_path @@ -145,33 +149,52 @@ def test_kv_cache_event_async_api(): asyncio.run(main()) -def test_llm_kv_events_api(): - llm = create_llm() - sampling_params = SamplingParams(max_tokens=6, temperature=0.01) +def check_events(llm, + requests, + sampling_params, + scheduling_params=None, + attention_dp_rank=None): - requests = [] - for i in range(3): - input_tokens = list(range(127 + i))[i:] - requests.append(input_tokens) + _ = llm.generate(requests[0], + sampling_params=sampling_params, + scheduling_params=scheduling_params) + time.sleep(1) + events = llm.get_kv_cache_events(5) - _ = llm.generate(requests[0], sampling_params=sampling_params) - events1 = llm.get_kv_cache_events(5) + # Created or stored event + if attention_dp_rank is None: + event = events.pop(0) # created event + assert event["event_id"] == 0 + assert event["data"]["type"] == "created" + while events: + event = events.pop(0) + if event: + assert event["event_id"] == 1 + assert event["data"]["type"] == "stored" + assert len(event["data"]["blocks"]) == 5 + else: + while events: + event = events.pop(0) + assert "attention_dp_rank" in event + if event and event["attention_dp_rank"] == attention_dp_rank: + assert event["event_id"] in [0, 1] + assert event["data"]["type"] in ["created", "stored"] + if event["data"]["type"] == "created": + assert event["event_id"] == 0 + if event["data"]["type"] == "stored": + assert event["event_id"] == 1 + assert len(event["data"]["blocks"]) == 5 - # Should have 1 stored event and 1 created event - event = events1.pop(0) # created event - while events1: - event = events1.pop(0) - if event: - assert event["event_id"] == 1 - assert event["data"]["type"] == "stored" - assert len(event["data"]["blocks"]) == 5 - - _ = llm.generate(requests[1], sampling_params=sampling_params) + _ = llm.generate(requests[1], + sampling_params=sampling_params, + scheduling_params=scheduling_params) + time.sleep(1) events2 = llm.get_kv_cache_events(5) while events2: event = events2.pop(0) - if event: + if event and (attention_dp_rank is None + or event.get("attention_dp_rank") == attention_dp_rank): if event["event_id"] == 2: # 2 removed events needed # should be a removed event to make space for context block @@ -185,12 +208,16 @@ def test_llm_kv_events_api(): assert event["data"]["type"] == "stored" assert len(event["data"]["blocks"]) == 5 - _ = llm.generate(requests[2], sampling_params=sampling_params) + _ = llm.generate(requests[2], + sampling_params=sampling_params, + scheduling_params=scheduling_params) + time.sleep(1) events3 = llm.get_kv_cache_events(5) while events3: event = events3.pop(0) - if event: + if event and (attention_dp_rank is None + or event.get("attention_dp_rank") == attention_dp_rank): if event["event_id"] == 5: assert event["data"]["type"] == "removed" assert event["data"]["block_hashes"] @@ -203,3 +230,46 @@ def test_llm_kv_events_api(): # no more events after request is finished assert not llm.get_kv_cache_events(5) + + +@pytest.mark.skip(reason="https://nvbugs/5445001") +def test_llm_kv_events_api(): + llm = create_llm() + sampling_params = SamplingParams(max_tokens=6, + temperature=0.01, + ignore_eos=True) + + requests = [] + for i in range(3): + input_tokens = list(range(127 + i))[i:] + requests.append(input_tokens) + + check_events(llm, requests, sampling_params) + + +@pytest.mark.skip(reason="https://nvbugs/5451407") +@skip_single_gpu +@pytest.mark.threadleak(enabled=False) +def test_llm_api_attention_dp_kv_events(): + + llm = LLM(model=llama_model_path, + tensor_parallel_size=2, + enable_attention_dp=True, + kv_cache_config=global_kvcache_config, + enable_autotuner=False) + + sampling_params = SamplingParams(max_tokens=6, + temperature=0.01, + ignore_eos=True) + + for attention_dp_rank in range(2): + requests = [] + for i in range(3): + input_tokens = list(range(127 + i))[i:] + requests.append(input_tokens) + + scheduling_params = SchedulingParams( + attention_dp_rank=attention_dp_rank, attention_dp_relax=False) + + check_events(llm, requests, sampling_params, scheduling_params, + attention_dp_rank) diff --git a/tests/unittest/llmapi/test_llm_models.py b/tests/unittest/llmapi/test_llm_models.py index 737511d215..4fbc00ddf7 100644 --- a/tests/unittest/llmapi/test_llm_models.py +++ b/tests/unittest/llmapi/test_llm_models.py @@ -110,7 +110,6 @@ def test_llm_phi_3_mini_4k(): sampling_params=phi3_mini_4k_sampling_params) -@pytest.mark.skip(reason="https://nvbugs/5371480") @force_ampere def test_llm_phi_3_small_8k(): phi_requirement_path = os.path.join( diff --git a/tests/unittest/llmapi/test_llm_multi_gpu.py b/tests/unittest/llmapi/test_llm_multi_gpu.py index b498a8cd7f..a92e640a8b 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu.py @@ -12,7 +12,7 @@ from tensorrt_llm._tensorrt_engine import LLM from tensorrt_llm.executor import GenerationExecutorProxy from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig, SamplingParams from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer -from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import PretrainedConfig from tensorrt_llm.models.llama.model import LLaMAForCausalLM diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index cb8dbf03c0..28d6bedf1b 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -4,7 +4,7 @@ import pytest from .test_llm import tinyllama_logits_processor_test_harness from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig -from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.lora_helper import LoraConfig from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness from .test_llm import _test_llm_capture_request_error diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 13708aae3c..6b78c46bd7 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -1,3 +1,4 @@ +import random from contextlib import contextmanager, nullcontext import pytest @@ -6,6 +7,7 @@ from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.llmapi.llm_args import PeftCacheConfig from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer +from tensorrt_llm.metrics import MetricNames from tensorrt_llm.sampling_params import SamplingParams # isort: off @@ -20,13 +22,11 @@ from .test_llm import (_test_llm_capture_request_error, get_model_path, run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler, tinyllama_logits_processor_test_harness) -from utils.util import (EnvVarsContextManager, force_ampere, - run_function_in_sub_process, similar, - skip_gpu_memory_less_than_40gb, +from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_gpu_memory_less_than_80gb, skip_gpu_memory_less_than_138gb) from utils.llm_data import llm_models_root -from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.executor.request import LoRARequest from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.mode import QuantAlgo @@ -197,6 +197,27 @@ def test_llm_perf_metrics(): assert perf_metrics.last_iter == perf_metrics.iter +def test_llm_prometheus(): + test_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(max_tokens=10, temperature=0.8, top_p=0.95) + llm = LLM(model=llama_model_path, + return_perf_metrics=True, + kv_cache_config=global_kvcache_config) + for test_prompt in test_prompts: + request_output = llm.generate(test_prompt, sampling_params) + assert request_output.metrics_dict is not None + assert MetricNames.REQUEST_QUEUE_TIME in request_output.metrics_dict + assert MetricNames.TPOT in request_output.metrics_dict + assert MetricNames.TTFT in request_output.metrics_dict + assert MetricNames.E2E in request_output.metrics_dict + assert request_output.outputs is not None + + @pytest.mark.parametrize("streaming", [True, False]) def test_llm_with_postprocess_parallel_and_result_handler(streaming): run_llm_with_postprocess_parallel_and_result_handler(streaming, @@ -234,14 +255,11 @@ def test_embedding_bias_with_torch_sampler_strategies(enable_mixed_sampler, sampling_params = SamplingParams(**sampling_kwargs) - llm_test_harness( - llama_model_path, - prompts, - ["Z Z Z Z Z Z"], - sampling_params=sampling_params, - backend="pytorch", - enable_trtllm_sampler=False, # Use TorchSampler to test all 3 paths - enable_mixed_sampler=enable_mixed_sampler) + llm_test_harness(llama_model_path, + prompts, ["Z Z Z Z Z Z"], + sampling_params=sampling_params, + backend="pytorch", + enable_mixed_sampler=enable_mixed_sampler) def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None: @@ -313,20 +331,7 @@ def test_llama_7b_lora_default_modules() -> None: llm.shutdown() -@pytest.mark.parametrize( - "lora_adapter_count_per_call, max_loras, max_cpu_loras, repeat_calls, repeats_per_call", - [ - # Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single - # llm.generate call, that's repeated twice. - ([ - 2, - ], 1, 2, 2, 3), - # Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU - # cache size < LoRA CPU cache size - ([2, 2, 2], 1, 3, 1, 1), - ]) -@skip_gpu_memory_less_than_40gb -def test_llama_7b_multi_lora_evict_load_new_adapters( +def _check_llama_7b_multi_lora_evict_load_new_adapters( lora_adapter_count_per_call: list[int], max_loras: int, max_cpu_loras: int, repeat_calls: int, repeats_per_call: int): # For LoRA checkpoints without finetuned embedding and lm_head, we can either: @@ -347,60 +352,66 @@ def test_llama_7b_multi_lora_evict_load_new_adapters( cuda_graph_config=None) -@pytest.mark.parametrize( - "lora_adapter_count_per_call, max_loras, max_cpu_loras, repeat_calls, repeats_per_call", - [ - # Test eviction, reloading new adapters and reloading previously evicted adapters from the LoRA CPU cache & GPU - # cache over multiple llm.generate call repeated twice (two calls with the same requests): - # At the end of the 1st llm.generate call: - # The LoRA caches should contain adapters 1, 2 and shouldn't contain adapter 0 (it should have been evicted). - # So in the 2nd call, the worker should: - # - Send req0 with adapter 0 weights (because it was previously evicted) - # - Send the other two requests without their adapter weights as they're already in LoRA CPU cache - # Then, handling of req0 that has weights but not in the cache should evict one of the other two adapters from - # the cache, causing that evicted adapter's request to fail because its weights aren't with the request and - # aren't in LoRA cache. - ([ - 3, - ], 2, 2, 2, 1), - ]) @skip_gpu_memory_less_than_40gb -def test_llama_7b_multi_lora_load_previously_cpu_cache_evicted_adapter_fails( - lora_adapter_count_per_call: list[int], max_loras: int, - max_cpu_loras: int, repeat_calls: int, repeats_per_call: int): - """Tests that trying to load a LoRA adapter after it was evicted from CPU cache fails with the expected - message, as this feature is currently not supported in favor of the performance improvement of not - sending the LoRA weights with every request after the first time. - NOTE: This test assumes the requests are handled in the order they're sent, if that's not true, then this test - may not get any error at all, which would cause it to fail. +def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache(): + """Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single + llm.generate call, that's repeated twice. """ # noqa: D205 - - def _check_contains_expected_message(stdout: str, stderr: str): - note_in_message = "Note that currently a request with LoRA task that was already loaded is sent" \ - " without its LoRA weights to save its serialization, copy and deserialization, so if this" \ - " LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported." - return note_in_message in stderr - - lora_config = LoraConfig(lora_target_modules=['attn_q', 'attn_k', 'attn_v'], - max_lora_rank=8, - max_loras=max_loras, - max_cpu_loras=max_cpu_loras) - with EnvVarsContextManager({"TLLM_WORKER_USE_SINGLE_PROCESS": "1"}): - child_stdout, child_stderr = run_function_in_sub_process( - target=check_llama_7b_multi_unique_lora_adapters_from_request, - args=(lora_adapter_count_per_call, repeat_calls, repeats_per_call, - LLM), - kwargs={ - "lora_config": lora_config, - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - "cuda_graph_config": None - }, - stop_waiting_criteria=_check_contains_expected_message) - - assert _check_contains_expected_message(child_stdout, child_stderr) + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[2], + max_loras=1, + max_cpu_loras=2, + repeat_calls=2, + repeats_per_call=3) +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache(): + """Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU + cache size < LoRA CPU cache size. + """ # noqa: D205 + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[2, 2, 2], + max_loras=1, + max_cpu_loras=3, + repeat_calls=1, + repeats_per_call=1) + + +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_read_from_cache_after_insert(): + """Test that loading and then using the same adapters loaded in cache works.""" + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[3], + max_loras=3, + max_cpu_loras=3, + repeat_calls=2, + repeats_per_call=1) + + +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_cache( +): + """Test eviction, reloading new adapters and reloading previously evicted adapters from the LoRA CPU cache & GPU + cache over multiple llm.generate call repeated twice (two calls with the same requests): + At the end of the 1st llm.generate call: + The LoRA caches should contain adapters 1, 2 and shouldn't contain adapter 0 (it should have been evicted). + So in the 2nd call, the worker should: + - Send req0 with adapter 0 weights (because it was previously evicted) + - Send the other two requests without their adapter weights as they're already in LoRA CPU cache + Then, handling of req0 that has weights but not in the cache should evict one of the other two adapters from + the cache, causing that evicted adapter's request to again load its weights from the file system, as they + aren't with the request and aren't in LoRA cache. + """ # noqa: D205 + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[3], + max_loras=2, + max_cpu_loras=2, + repeat_calls=2, + repeats_per_call=1) + + +@skip_gpu_memory_less_than_40gb def test_llama_7b_peft_cache_config_affects_peft_cache_size(): """Tests that LLM arg of peft_cache_config affects the peft cache sizes. @@ -436,6 +447,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): cuda_graph_config=None) +@skip_gpu_memory_less_than_40gb def test_llama_7b_lora_config_overrides_peft_cache_config(): """Tests that cache size args in lora_config LLM arg override the cache size parameters in peft_cache_config LLM arg. @@ -457,6 +469,7 @@ def test_llama_7b_lora_config_overrides_peft_cache_config(): # TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high # https://jirasw.nvidia.com/browse/TRTLLM-5045 +@pytest.mark.skip(reason="https://nvbugs/5448464") @skip_gpu_memory_less_than_138gb def test_nemotron_nas_lora() -> None: lora_config = LoraConfig(lora_dir=[ @@ -791,3 +804,17 @@ def test_gqa_nemo_lora(tmp_path): f"got: {base_outputs[0].outputs[0].text}" finally: llm.shutdown() + + +class TestLlmError: + + def test_max_num_token_check(self): + """ LLM should raise error when got prompt length exceed the valid range. """ + llm = LLM(llama_model_path, + kv_cache_config=global_kvcache_config, + max_num_tokens=100) + + with pytest.raises(ValueError, + match="should not exceed max_num_tokens"): + ids = [random.randint(10, 100) for _ in range(101)] + llm.generate([ids]) diff --git a/tests/unittest/llmapi/test_utils.py b/tests/unittest/llmapi/test_utils.py index d742283ca5..fc5876cdb1 100644 --- a/tests/unittest/llmapi/test_utils.py +++ b/tests/unittest/llmapi/test_utils.py @@ -13,7 +13,9 @@ def test_api_status_registry(): def _my_method(self, *args, **kwargs): pass - assert ApiStatusRegistry.get_api_status(_my_method) == "prototype" + # will always keep the first status, and the behaviour will be unknown if + # one method is registered with a different status in different files. + assert ApiStatusRegistry.get_api_status(_my_method) == "beta" class App: diff --git a/tests/unittest/others/test_mapping.py b/tests/unittest/others/test_mapping.py index 6d836f220b..bc9839239b 100644 --- a/tests/unittest/others/test_mapping.py +++ b/tests/unittest/others/test_mapping.py @@ -44,3 +44,40 @@ class TestMapping(unittest.TestCase): self.assertTrue(m.is_last_pp_rank()) self.assertEqual(m.prev_pp_rank(), 4) self.assertEqual(m.next_pp_rank(), 0) + + m = Mapping(world_size=2, rank=0, cp_size=2) + self.assertEqual(len(m.tp_groups), 2) + self.assertEqual(len(m.pp_groups), 2) + self.assertEqual(len(m.cp_groups), 1) + self.assertEqual(m.tp_group, [0]) + self.assertEqual(m.pp_group, [0]) + self.assertEqual(m.cp_group, [0, 1]) + + m = Mapping(world_size=8, rank=3, tp_size=2, pp_size=2, cp_size=2) + self.assertEqual(len(m.tp_groups), 4) + self.assertEqual(len(m.pp_groups), 4) + self.assertEqual(len(m.cp_groups), 4) + self.assertEqual(m.tp_group, [2, 3]) + self.assertEqual(m.pp_group, [3, 7]) + self.assertEqual(m.cp_group, [1, 3]) + self.assertTrue(m.is_first_pp_rank()) + self.assertFalse(m.is_last_pp_rank()) + self.assertFalse(m.is_first_cp_rank()) + self.assertTrue(m.is_last_cp_rank()) + self.assertEqual(m.prev_pp_rank(), 7) + self.assertEqual(m.next_pp_rank(), 7) + self.assertEqual(m.prev_cp_rank(), 1) + self.assertEqual(m.next_cp_rank(), 1) + + m = Mapping(world_size=16, rank=9, tp_size=2, pp_size=2, cp_size=4) + self.assertEqual(m.tp_group, [8, 9]) + self.assertEqual(m.pp_group, [1, 9]) + self.assertEqual(m.cp_group, [9, 11, 13, 15]) + self.assertFalse(m.is_first_pp_rank()) + self.assertTrue(m.is_last_pp_rank()) + self.assertTrue(m.is_first_cp_rank()) + self.assertFalse(m.is_last_cp_rank()) + self.assertEqual(m.prev_pp_rank(), 1) + self.assertEqual(m.next_pp_rank(), 1) + self.assertEqual(m.prev_cp_rank(), 15) + self.assertEqual(m.next_cp_rank(), 11) diff --git a/tests/unittest/others/test_multimodal_registry.py b/tests/unittest/others/test_multimodal_registry.py new file mode 100644 index 0000000000..a6eb5345e0 --- /dev/null +++ b/tests/unittest/others/test_multimodal_registry.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement) + + +class TestMultimodalPlaceholderRegistry(unittest.TestCase): + + def setUp(self): + self.model_type = "test_model_type" + self.placeholder_metadata = MultimodalPlaceholderMetadata( + placeholder_map={ + "image": "IMAGE_PLACEHOLDER", + "video": "VIDEO_PLACEHOLDER" + }, + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + placeholders_separator="\n") + + def test_new_registration(self): + MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata( + self.model_type, self.placeholder_metadata) + self.assertEqual( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder_metadata( + self.model_type), self.placeholder_metadata) + MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata( + self.model_type) + + def test_registered_model_types(self): + pre_reg_model_types = list( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()) + + # register the model type + MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata( + self.model_type, self.placeholder_metadata) + + post_reg_model_types = list( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()) + self.assertEqual( + len(pre_reg_model_types) + 1, len(post_reg_model_types)) + self.assertIn(self.model_type, post_reg_model_types) + + MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata( + self.model_type) + + def test_validity(self): + MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata( + self.model_type, self.placeholder_metadata) + self.assertTrue( + MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "image")) + self.assertTrue( + MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "video")) + self.assertFalse( + MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "audio")) + + MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata( + self.model_type) + + def test_model_types_per_modality(self): + pre_reg_image_model_types = list( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types()) + pre_reg_video_model_types = list( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types()) + pre_reg_audio_model_types = list( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types()) + + # register the model type for image and video + MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata( + self.model_type, self.placeholder_metadata) + + post_reg_image_model_types = list( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types()) + post_reg_video_model_types = list( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types()) + post_reg_audio_model_types = list( + MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types()) + self.assertEqual( + len(pre_reg_image_model_types) + 1, len(post_reg_image_model_types)) + self.assertEqual( + len(pre_reg_video_model_types) + 1, len(post_reg_video_model_types)) + self.assertEqual(len(pre_reg_audio_model_types), + len(post_reg_audio_model_types)) + self.assertIn(self.model_type, post_reg_image_model_types) + self.assertIn(self.model_type, post_reg_video_model_types) + self.assertNotIn(self.model_type, post_reg_audio_model_types) + + MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata( + self.model_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/test_pip_install.py b/tests/unittest/test_pip_install.py index d75bfbaf42..11288e09ce 100644 --- a/tests/unittest/test_pip_install.py +++ b/tests/unittest/test_pip_install.py @@ -51,9 +51,6 @@ def test_pip_install(): help="The wheel path") args = parser.parse_args() - if not os.environ.get("CUDA_HOME"): - os.environ["CUDA_HOME"] = "/usr/local/cuda" - print("########## Install required system libs ##########") if not os.path.exists("/usr/local/mpi/bin/mpicc"): subprocess.check_call("apt-get -y install libopenmpi-dev", shell=True) diff --git a/tests/unittest/trt/attention/test_gpt_attention.py b/tests/unittest/trt/attention/test_gpt_attention.py index cbe5c1309e..38638b198d 100644 --- a/tests/unittest/trt/attention/test_gpt_attention.py +++ b/tests/unittest/trt/attention/test_gpt_attention.py @@ -873,18 +873,17 @@ class TestFunctional(unittest.TestCase): ConfigCls = GPTBigCodeConfig AttentionCls = GPTBigCodeAttention - configuration = ConfigCls( - hidden_size=hidden_size, - num_hidden_layers=1, - num_attention_heads=num_heads, - vocab_size=51200, - use_cache=True, - resid_pdrop=0, - embd_pdrop=0, - attn_pdrop=0, - hidden_act='gelu', - torch_dtype=dtype, - ) + configuration = ConfigCls(hidden_size=hidden_size, + num_hidden_layers=1, + num_attention_heads=num_heads, + vocab_size=51200, + use_cache=True, + resid_pdrop=0, + embd_pdrop=0, + attn_pdrop=0, + hidden_act='gelu', + torch_dtype=dtype, + attn_implementation='eager') if attention_type in ['gptj_attention', 'llama_attention']: configuration.rotary_dim = head_size @@ -1259,38 +1258,35 @@ class TestFunctional(unittest.TestCase): use_cache=True)[0] torch_present = torch_present.to_legacy_cache() elif attention_type == 'gptj_attention': - torch_output, torch_present = attention( - input_tensor, - layer_past=None, - position_ids=position_ids, - attention_mask=attention_mask, - use_cache=True) + torch_present = DynamicCache() + torch_output = attention(input_tensor, + layer_past=torch_present, + position_ids=position_ids, + attention_mask=attention_mask, + use_cache=True)[0] + torch_present = torch_present.to_legacy_cache() elif attention_type == 'gpt_bigcode_attention': attention_mask = _prepare_4d_attention_mask( ctx_attention_mask, dtype=str_dtype_to_torch(dtype), tgt_len=in_len) - # source shape = (b, 1, s_query, s_key) - # target shape = (b, s_query, h, s_key) + # target shape = (b, h, s_query, s_key) attention_mask = (attention_mask - >= 0).permute([0, 2, 1, 3]).expand( - batch_size, in_len, num_heads, in_len) - torch_output, torch_present = attention( - input_tensor, - layer_past=None, - attention_mask=attention_mask, - use_cache=True) + >= 0).expand(batch_size, num_heads, + in_len, in_len) + torch_present = DynamicCache() + torch_output = attention(input_tensor, + layer_past=torch_present, + attention_mask=attention_mask, + use_cache=True)[0] + torch_present = torch_present.to_legacy_cache() else: raise RuntimeError("attention_type not properly set") torch.cuda.synchronize() - if attention_type in ['llama_attention', 'gpt2_attention']: - kv_dequant_scale, kv_quant_scale = get_kv_quant_scale( - torch_present[0]) - else: - kv_dequant_scale, kv_quant_scale = get_kv_quant_scale( - torch_present) + kv_dequant_scale, kv_quant_scale = get_kv_quant_scale( + torch_present[0]) if enable_remove_input_padding: shape_dict['input'] = (batch_size * (in_len // 2), @@ -1330,10 +1326,7 @@ class TestFunctional(unittest.TestCase): torch_output[:, :in_len // 2, :].to( torch.float32).cpu().numpy(), atol=5e-3) - if attention_type in ['llama_attention', 'gpt2_attention']: - verify_kv_cache(torch_present[0]) - else: - verify_kv_cache(torch_present) + verify_kv_cache(torch_present[0]) else: # Generation stage @@ -1408,24 +1401,27 @@ class TestFunctional(unittest.TestCase): use_cache=True)[0] torch_present = torch_present.to_legacy_cache() elif attention_type == 'gptj_attention': - torch_output, torch_present = attention( - input_tensor, - layer_past=torch_present, - position_ids=position_ids, - attention_mask=attention_mask, - use_cache=True) + torch_present = DynamicCache.from_legacy_cache( + torch_present) + torch_output = attention(input_tensor, + layer_past=torch_present, + position_ids=position_ids, + attention_mask=attention_mask, + use_cache=True)[0] + torch_present = torch_present.to_legacy_cache() elif attention_type == 'gpt_bigcode_attention': - # source shape = (b, 1, 1, s_key) - # target shape = (b, 1, h, s_key) + # target shape = (b, h, 1, s_key) key_seqlen = in_len + step # ctx_attention_mask.shape[1] attention_mask = (attention_mask - >= 0).permute([0, 2, 1, 3]).expand( - batch_size, 1, num_heads, key_seqlen) - torch_output, torch_present = attention( - input_tensor, - layer_past=torch_present, - use_cache=True, - attention_mask=attention_mask) + >= 0).expand(batch_size, num_heads, 1, + key_seqlen) + torch_present = DynamicCache.from_legacy_cache( + torch_present) + torch_output = attention(input_tensor, + layer_past=torch_present, + use_cache=True, + attention_mask=attention_mask)[0] + torch_present = torch_present.to_legacy_cache() def tile_beam_width(tensor: torch.Tensor, num_beams: int): if num_beams == 1: diff --git a/tests/unittest/trt/attention/test_gpt_attention_IFB.py b/tests/unittest/trt/attention/test_gpt_attention_IFB.py index 68c45583ab..9af84f05a3 100644 --- a/tests/unittest/trt/attention/test_gpt_attention_IFB.py +++ b/tests/unittest/trt/attention/test_gpt_attention_IFB.py @@ -566,18 +566,17 @@ class TestFunctional(unittest.TestCase): ConfigCls = GPTBigCodeConfig AttentionCls = GPTBigCodeAttention - configuration = ConfigCls( - hidden_size=hidden_size, - num_hidden_layers=1, - num_attention_heads=num_heads, - vocab_size=51200, - use_cache=True, - resid_pdrop=0, - embd_pdrop=0, - attn_pdrop=0, - hidden_act='gelu', - torch_dtype=dtype, - ) + configuration = ConfigCls(hidden_size=hidden_size, + num_hidden_layers=1, + num_attention_heads=num_heads, + vocab_size=51200, + use_cache=True, + resid_pdrop=0, + embd_pdrop=0, + attn_pdrop=0, + hidden_act='gelu', + torch_dtype=dtype, + attn_implementation='eager') if attention_type == 'llama_attention': configuration.num_key_value_heads = num_kv_heads configuration.rope_theta = rope_base @@ -787,18 +786,19 @@ class TestFunctional(unittest.TestCase): position_ids=position_ids, attention_mask=attention_mask, use_cache=True) + torch_present = layer_past elif attention_type == 'gpt_bigcode_attention': - # source shape = (b, 1, s_query or 1, s_key) - # target shape = (b, s_query or 1, h, s_key) - attention_mask = (attention_mask >= 0).permute( - [0, 2, 1, - 3]).expand(input.shape[0], in_len if step == 0 else 1, - num_heads, in_len + step) + # target shape = (b, h, s_query or 1, s_key) + attention_mask = (attention_mask + >= 0).expand(input.shape[0], num_heads, + in_len if step == 0 else 1, + in_len + step) torch_output, torch_present = attention( input, layer_past=layer_past, attention_mask=attention_mask, use_cache=True) + torch_present = layer_past else: raise RuntimeError("attention_type not properly set") @@ -1010,23 +1010,16 @@ class TestFunctional(unittest.TestCase): (local_beam_width, input_length, hidden_size)) # llama/gpt2 uses DynamicCache - if attention_type in ['llama_attention', 'gpt2_attention']: - past_key_values = DynamicCache.from_legacy_cache( - torch_cache_list[req_idx]) - else: - past_key_values = torch_cache_list[req_idx] + past_key_values = DynamicCache.from_legacy_cache( + torch_cache_list[req_idx]) torch_out, past_key_values = torch_exec( step, torch_in, ctx_attention_mask_list[req_idx], req_idx, past_key_values) # llama/gpt2 uses DynamicCache - if attention_type in ['llama_attention', 'gpt2_attention']: - torch_cache_list[req_idx] = past_key_values.to_legacy_cache( - ) - past_key_values = torch_cache_list[req_idx][0] - else: - torch_cache_list[req_idx] = past_key_values + torch_cache_list[req_idx] = past_key_values.to_legacy_cache() + past_key_values = torch_cache_list[req_idx][0] if use_fp8_kv_cache or use_int8_kv_cache: max_kv_cache = max( diff --git a/tests/unittest/trt/functional/test_moe.py b/tests/unittest/trt/functional/test_moe.py index 120905e834..e5dcefcbaa 100644 --- a/tests/unittest/trt/functional/test_moe.py +++ b/tests/unittest/trt/functional/test_moe.py @@ -1188,7 +1188,7 @@ class TestMoE(unittest.TestCase): moe_weight_wrapper.weights_block_scaling_factor_interleaved.value = ( np.ascontiguousarray( torch_to_numpy( - torch.ops.trtllm.nvfp4_block_scale_interleave( + torch.ops.trtllm.block_scale_interleave( scale_factor.view(torch.uint8).contiguous()).view( scale_factor.dtype).reshape( scale_factor.shape).view(torch.uint8)))) diff --git a/tests/unittest/trt/model/test_nemotron_nas.py b/tests/unittest/trt/model/test_nemotron_nas.py index 07e1c1939d..5a99291ddf 100644 --- a/tests/unittest/trt/model/test_nemotron_nas.py +++ b/tests/unittest/trt/model/test_nemotron_nas.py @@ -382,6 +382,7 @@ class TestNemotronNas(unittest.TestCase): @parameterized.expand(get_loader_test_cases, name_func=unittest_name_func) def test_allclose_to_hf(self, hf_model_dir: str, params: TestParams): + self.skipTest(f"https://nvbugs/5444611") hf_model = transformers.AutoModelForCausalLM.from_pretrained( hf_model_dir, trust_remote_code=True, @@ -827,6 +828,7 @@ class TestNemotronNas(unittest.TestCase): def test_convert_model_from_hf(self, model_dir: Optional[str], preloaded: bool, tp_size: int, pp_size: int, dtype: str) -> None: + self.skipTest(f"https://nvbugs/5444611") ckpt_path = Path(llm_models_root(check=True), "nvsmall/tests", model_dir) diff --git a/tests/unittest/trt/quantization/test_mode.py b/tests/unittest/trt/quantization/test_mode.py index 42f8a5a7f1..d211a4a0ed 100644 --- a/tests/unittest/trt/quantization/test_mode.py +++ b/tests/unittest/trt/quantization/test_mode.py @@ -44,7 +44,7 @@ class TestQuantMode(unittest.TestCase): def test_count(self): # Make sure the COUNT value is as expected - change that test if you add a new flag. - self.assertEqual(QuantMode.COUNT.value, 1 << 15) + self.assertEqual(QuantMode.COUNT.value, 1 << 17) def test_from_description(self): # Test weight only. diff --git a/tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py b/tests/unittest/trt/quantization/test_moe_weight_only_quant_matmul.py similarity index 60% rename from tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py rename to tests/unittest/trt/quantization/test_moe_weight_only_quant_matmul.py index c48326e205..d741978c2a 100644 --- a/tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py +++ b/tests/unittest/trt/quantization/test_moe_weight_only_quant_matmul.py @@ -14,29 +14,33 @@ # limitations under the License. import unittest +import pytest + # isort: off import torch # isort: on from parameterized import parameterized -from utils.util import (create_session, run_session, skip_non_ada_unittest, +from utils.util import (create_session, run_session, + skip_neither_ada_nor_hopper_unittest, unittest_name_func) import tensorrt_llm import tensorrt_llm.quantization.functional from tensorrt_llm import Tensor -from tensorrt_llm._utils import (str_dtype_to_trt, torch_to_numpy, - trt_dtype_to_str) +from tensorrt_llm._utils import (get_sm_version, str_dtype_to_trt, + torch_to_numpy, trt_dtype_to_str) from tensorrt_llm.layers.moe import MoeConfig from tensorrt_llm.quantization import QuantMode from . import _utils -class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): +class TestMoEWeightOnlyQuantMatmul(unittest.TestCase): def setUp(self): torch.manual_seed(0) + torch.cuda.manual_seed(0) tensorrt_llm.logger.set_level('error') def create_trt_session( @@ -64,11 +68,15 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): network = builder.create_network() dtype = str_dtype_to_trt(str_dtype) norm_mode = MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE - quant_mode = QuantMode.use_weight_only(True, True) k = act.shape[1] - n = weight_scaling_factor_1.shape[-1] // 2 + if has_pre_quant: + n = fc2_prequant_scale.shape[-1] + else: + n = weight_scaling_factor_1.shape[-1] // 2 num_experts = weight_scaling_factor_1.shape[0] + use_int8 = True if self.quant_mode.is_int8_weight_only() else False + with tensorrt_llm.net_guard(network): trt_key = Tensor(name='input_hidden_states', shape=act.shape, @@ -87,18 +95,25 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): hidden_act="swiglu", bias=False, dtype=dtype, - quant_mode=quant_mode, + quant_mode=self.quant_mode, pre_quant_scale=has_pre_quant, zero=has_zero, use_w4a8_awq=has_alpha, + use_int8_weight=use_int8, group_size=group_size) moe.router.weight.value = torch_to_numpy(router.cpu()) moe.fc.weight.value = torch_to_numpy(fc1_weights.cpu()) moe.proj.weight.value = torch_to_numpy(fc2_weights.cpu()) - moe.fc.weights_scaling_factor.value = torch_to_numpy( - weight_scaling_factor_1.cpu()) - moe.proj.weights_scaling_factor.value = torch_to_numpy( - weight_scaling_factor_2.cpu()) + if group_size != -1: + moe.fc.weights_scaling_factor.value = torch_to_numpy( + weight_scaling_factor_1.cpu()) + moe.proj.weights_scaling_factor.value = torch_to_numpy( + weight_scaling_factor_2.cpu()) + else: + moe.fc.per_channel_scale.value = torch_to_numpy( + weight_scaling_factor_1.cpu()) + moe.proj.per_channel_scale.value = torch_to_numpy( + weight_scaling_factor_2.cpu()) if has_pre_quant: moe.fc.prequant_scaling_factor.value = torch_to_numpy( fc1_prequant_scale.cpu()) @@ -117,8 +132,8 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): session = create_session(builder, network, precision=trt_dtype_to_str(dtype), - int8=False, - quant_mode=quant_mode) + int8=use_int8, + quant_mode=self.quant_mode) return session def _woq_moe_groupwise_matmul(self, @@ -202,18 +217,50 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): ref_weight_1 += zero_1.repeat_interleave(group_size, dim=1) ref_weight_2 += zero_2.repeat_interleave(group_size, dim=1) activation_type = torch.float8_e4m3fn if has_alpha else activation_dtype + # Hopper w4a8 does not interleave weight + do_weight_interleave = get_sm_version() != 90 or not has_alpha cuda_q_weight_1 = preprocessor( - unprocessed_weight_1.cpu(), quantized_weight_dtype, - activation_type).view(activation_dtype).cpu() + unprocessed_weight_1.cpu(), + quantized_weight_dtype, + activation_type, + do_weight_interleave=do_weight_interleave).view( + activation_dtype).cpu() cuda_q_weight_2 = preprocessor( - unprocessed_weight_2.cpu(), quantized_weight_dtype, - activation_type).view(activation_dtype).cpu() - if has_alpha and activation_dtype == torch.bfloat16: + unprocessed_weight_2.cpu(), + quantized_weight_dtype, + activation_type, + do_weight_interleave=do_weight_interleave).view( + activation_dtype).cpu() + if get_sm_version() == 89 and has_alpha: scale_1 = scale_1.to(torch.float16).view(activation_dtype) scale_2 = scale_2.to(torch.float16).view(activation_dtype) zero_1 = zero_1.to(torch.float16).view(activation_dtype) zero_2 = zero_2.to(torch.float16).view(activation_dtype) + if get_sm_version() == 90 and has_alpha: + if has_zero: + pytest.skip( + "has_zero is not supported in Hopper with WINT4AFP8.") + + def interleave_scales(scales: torch.Tensor, interleave_dim: int): + # [num_experts, num_groups, num_cols] --> [num_experts, num_groups // interleave, num_cols * interleave] + # Note: num_groups = num_rows // group_size + E, G, C = scales.shape + I = tensorrt_llm.quantization.functional.get_weight_scale_interleave_factor( + interleave_dim, group_size) + assert G % I == 0, f"Group dimension ({G}) must be divisible by interleave factor ({I})." + scales_interleaved = scales.reshape(E, G // I, I, C) + scales_interleaved = scales_interleaved.permute(0, 1, 3, 2) + scales_interleaved = scales_interleaved.reshape( + E, G // I, C * I) + return scales_interleaved.contiguous() + + scale_1 = scale_1.to(torch.bfloat16).view(activation_dtype) + scale_2 = scale_2.to(torch.bfloat16).view(activation_dtype) + scale_1 = interleave_scales(scale_1, k) + scale_2 = interleave_scales(scale_2, n) + zero_1, zero_2 = None, None + session = self.create_trt_session( activation_dtype_str, activation, router, pre_quant_scale_1, pre_quant_scale_2, cuda_q_weight_1, cuda_q_weight_2, scale_1, @@ -262,6 +309,97 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): ref = results.view(*inputs.shape) _utils.woq_assert_near_eq(ref, out, 2) + def _woq_moe_matmul_per_channel(self, + m, + n, + k, + num_experts, + activation_dtype_str, + quantized_weight_dtype, + top_k=2): + + activation_dtype = tensorrt_llm._utils.str_dtype_to_torch( + activation_dtype_str) + activation = torch.randn(m, k, dtype=activation_dtype, device="cuda") + router = torch.randn((num_experts, k), + dtype=torch.float32, + device="cuda") + + num_weights_in_32_bits = 4 + + assert n % num_weights_in_32_bits == 0, f"n must be a multiple of {num_weights_in_32_bits}" + unprocessed_int_weight_1 = torch.randint( + -2**31, + 2**31, (num_experts, k, n * 2 // num_weights_in_32_bits), + dtype=torch.int32, + device="cuda") + unprocessed_int_weight_2 = torch.randint( + -2**31, + 2**31, (num_experts, n, k // num_weights_in_32_bits), + dtype=torch.int32, + device="cuda") + unprocessed_weight_1 = unprocessed_int_weight_1.view(torch.int8) + unprocessed_weight_2 = unprocessed_int_weight_2.view(torch.int8) + + scale_1 = torch.randn( + num_experts, 1, n * 2, dtype=activation_dtype, device="cuda") / k + scale_2 = torch.randn( + num_experts, 1, k, dtype=activation_dtype, device="cuda") / n + + ref_weight_1 = unprocessed_weight_1 * scale_1 + ref_weight_2 = unprocessed_weight_2 * scale_2 + scale_1 = scale_1.squeeze(1) + scale_2 = scale_2.squeeze(1) + + preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm + + cuda_q_weight_1 = preprocessor(unprocessed_weight_1.cpu(), + quantized_weight_dtype, + activation_dtype).cpu() + cuda_q_weight_2 = preprocessor(unprocessed_weight_2.cpu(), + quantized_weight_dtype, + activation_dtype).cpu() + + session = self.create_trt_session(activation_dtype_str, activation, + router, None, None, cuda_q_weight_1, + cuda_q_weight_2, scale_1, scale_2, + None, None, None, None, top_k, False, + False, False, -1) + + inputs = {"input_hidden_states": activation} + outputs = run_session(session, inputs) + out = outputs['output'].float() + + # ref + inputs = activation.cuda().float() + inputs_merged = inputs.view(-1, inputs.shape[-1]) + routing = torch.matmul(inputs_merged, router.T.float()) + router_probs = torch.softmax(routing, 1, dtype=inputs.dtype) + topk = torch.topk(router_probs, top_k) + results = torch.zeros_like(inputs_merged) + for i, (scales, experts) in enumerate(zip(topk.values, topk.indices)): + scales /= sum(scales) + input = inputs_merged[i, :] + for scale, expert in zip(scales, experts): + input = inputs_merged[i, :] + fc1_qd = ref_weight_1[expert].cuda().float() + fc1 = torch.matmul(input, fc1_qd) + fc1, gate = fc1.chunk(2, dim=-1) + fc1 = fc1 * torch.nn.functional.silu(gate) + + fc2_qd = ref_weight_2[expert].cuda().float() + final = torch.matmul(fc1, fc2_qd) + results[i] += scale * final + ref = results.view(*inputs.shape) + _utils.woq_assert_near_eq(ref, out, 1) + + @parameterized.expand([(1, 14336, 4096, 8, "float16"), + (1, 14336, 4096, 8, "bfloat16")], + name_func=unittest_name_func) + def test_moe_w8a16(self, m, n, k, experts, dtype): + self.quant_mode = QuantMode.use_weight_only(False, False) + self._woq_moe_matmul_per_channel(m, n, k, experts, dtype, torch.int8) + @parameterized.expand([(1, 14336, 4096, 8, "float16", True, True), (1, 14336, 4096, 8, "float16", True, False), (1, 14336, 4096, 8, "float16", False, True), @@ -269,7 +407,9 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): (1, 14336, 4096, 8, "bfloat16", True, False), (1, 14336, 4096, 8, "bfloat16", False, True)], name_func=unittest_name_func) - def test_moe_w4a16(self, m, n, k, experts, dtype, has_pre_quant, has_zero): + def test_moe_w4a16_groupwise(self, m, n, k, experts, dtype, has_pre_quant, + has_zero): + self.quant_mode = QuantMode.use_weight_only(True, True) self._woq_moe_groupwise_matmul(m, n, k, experts, dtype, torch.quint4x2, has_pre_quant, has_zero, False) @@ -278,9 +418,10 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): (1, 14336, 4096, 8, "bfloat16", True, False), (1, 14336, 4096, 8, "bfloat16", True, True)], name_func=unittest_name_func) - @skip_non_ada_unittest - def test_moe_w4a8(self, m, n, k, experts, dtype, has_pre_quant, has_zero): - + @skip_neither_ada_nor_hopper_unittest + def test_moe_w4a8_groupwise(self, m, n, k, experts, dtype, has_pre_quant, + has_zero): + self.quant_mode = QuantMode.use_weight_only(True, True) self._woq_moe_groupwise_matmul(m, n, k, experts, dtype, torch.quint4x2, has_pre_quant, has_zero, True) diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py index 07cd4c20cc..893af5d93b 100644 --- a/tests/unittest/utils/util.py +++ b/tests/unittest/utils/util.py @@ -1,13 +1,9 @@ -import multiprocessing import os -import sys -import time import unittest from contextlib import contextmanager from difflib import SequenceMatcher -from multiprocessing.connection import Connection from pathlib import Path -from typing import Any, Callable, Generator, Mapping, Tuple +from typing import Any, Generator import pynvml import pytest @@ -433,88 +429,16 @@ def duplicate_list_to_length(list: list[Any], target_length: int) -> list[Any]: return duplicated_list -def _target_wrapper(target: Callable, stdout_pipe: Connection, - stderr_pipe: Connection, *args, **kwargs) -> None: - - class PipeWriter: - - def __init__(self, conn: Connection): - self.conn = conn - - def write(self, s: str): - self.conn.send_bytes(s.encode("UTF8")) - - def flush(self): - pass - - sys.stdout = PipeWriter(stdout_pipe) - sys.stderr = PipeWriter(stderr_pipe) - target(*args, **kwargs) - - -def run_function_in_sub_process(target: Callable, - args: tuple, - kwargs: Mapping[str, Any], - stop_waiting_criteria: Callable, - poll_interval_seconds: int = 5, - timeout_seconds: int = 240) -> Tuple[str, str]: - multiprocessing.set_start_method("spawn", force=True) - parent_stdout_pipe, child_stdout_pipe = multiprocessing.Pipe() - parent_stderr_pipe, child_stderr_pipe = multiprocessing.Pipe() - child_process = multiprocessing.Process( - target=_target_wrapper, - args=[target, child_stdout_pipe, child_stderr_pipe] + list(args), - kwargs=kwargs) - child_process.start() - child_stdout_pipe.close() - child_stderr_pipe.close() - - def _read_from_pipe(pipe: Connection): - out = "" - while pipe.poll(timeout=0.1): - try: - out += pipe.recv_bytes().decode("UTF8") - except Exception: - break - return out - - child_stdout = "" - child_stderr = "" - try: - total_waiting_seconds = 0 - while child_process.is_alive( - ) and total_waiting_seconds < timeout_seconds: - child_stdout += _read_from_pipe(parent_stdout_pipe) - child_stderr += _read_from_pipe(parent_stderr_pipe) - if stop_waiting_criteria(child_stdout, child_stderr): - break - time.sleep(poll_interval_seconds) - total_waiting_seconds += poll_interval_seconds - finally: - parent_stdout_pipe.close() - parent_stderr_pipe.close() - if child_process.is_alive(): - child_process.terminate() - - assert total_waiting_seconds < timeout_seconds, "Reached timeout while waiting for target" - return child_stdout, child_stderr - - -class EnvVarsContextManager: - - def __init__(self, new_env_vars: dict[str, str]): - self._env_vars = new_env_vars - self._original_value = None - - def __enter__(self): - self._original_vars = { - var_name: os.environ[var_name] - for var_name in self._env_vars.keys() if var_name in os.environ - } - os.environ.update(self._env_vars) - - def __exit__(self, type, value, traceback): - os.environ.update(self._original_vars) - for var_name in self._env_vars.keys(): - if var_name not in self._original_vars: - os.environ.pop(var_name) +# Check a certain percentage of elements in two tensors are within a tolerance +def check_accuracy(a, b, atol, rtol, percent): + assert a.shape == b.shape + assert a.dtype == b.dtype + a = a.to(torch.float32) + b = b.to(torch.float32) + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if not (mismatch_percent < 1 - percent): + raise Exception("Mismatch percentage is %f for rtol %f" % + (mismatch_percent, rtol)) diff --git a/triton_backend/requirements.txt b/triton_backend/requirements.txt index 3fcf5762e1..27cb5d4105 100644 --- a/triton_backend/requirements.txt +++ b/triton_backend/requirements.txt @@ -1,7 +1,7 @@ regex fire tritonclient[all] -transformers==4.51.0 +transformers==4.55.0 pandas tabulate flash_attn diff --git a/triton_backend/tools/tests/test_llmapi_cancel.py b/triton_backend/tools/tests/test_llmapi_cancel.py new file mode 100644 index 0000000000..4cd8c0c606 --- /dev/null +++ b/triton_backend/tools/tests/test_llmapi_cancel.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +from functools import partial + +import numpy as np +from tritonclient import grpc as grpcclient +from tritonclient.utils import InferenceServerException + +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/..') +from llmapi_client import (UserData, _prepare_inputs, callback, + prepare_stop_signals) + +if __name__ == "__main__": + input_data = np.array([ + "The current time is", + ], dtype=object) + output_len = 100 + inputs = _prepare_inputs(input_data, output_len) + + stop_inputs = prepare_stop_signals() + request_id = 1 + user_data = UserData() + with grpcclient.InferenceServerClient( + url="localhost:8001", + verbose=False, + ssl=False, + root_certificates=None, + private_key=None, + certificate_chain=None, + ) as triton_client: + + # Send stop request for non-existing request + triton_client.async_infer( + "tensorrt_llm", + stop_inputs, + request_id=str(request_id), # Request does not exist yet + callback=partial(callback, user_data), + parameters={'Streaming': False}) + + result = user_data._completed_requests.get() + assert isinstance(result, InferenceServerException) + assert result.status() == "StatusCode.CANCELLED" + + # Send actual request + infer_response = triton_client.async_infer( + "tensorrt_llm", + inputs, + request_id=str(request_id), + callback=partial(callback, user_data), + parameters={'Streaming': False}) + + result = user_data._completed_requests.get() + print( + f'Output text: {result.as_numpy("text_output")[0].decode("utf-8")}') + + # Cancel request after it is completed + infer_response.cancel() + + # Send stop request for completed request + triton_client.async_infer("tensorrt_llm", + stop_inputs, + request_id=str(request_id), + callback=partial(callback, user_data), + parameters={'Streaming': False}) + + cancel_result = user_data._completed_requests.get() + assert isinstance(cancel_result, InferenceServerException) + assert cancel_result.status() == "StatusCode.CANCELLED" + + # Send a second request to check if server is still healthy + infer_response_2 = triton_client.async_infer( + "tensorrt_llm", + inputs, + request_id=str(request_id + 1), + callback=partial(callback, user_data), + parameters={'Streaming': False}) + + # Get result of second request + result_2 = user_data._completed_requests.get() + print('Got completed request') + + print( + f'Output text: {result_2.as_numpy("text_output")[0].decode("utf-8")}' + ) + + # Check that both results match + assert np.array_equal(result.as_numpy("text_output"), + result_2.as_numpy("text_output"))