mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Merge branch 'main' into fix/internvl_exmaple_1
This commit is contained in:
commit
8ad6e9d69b
@ -180,9 +180,11 @@ def main():
|
||||
env_files = [
|
||||
JENKINS_PROPS_PATH,
|
||||
DEV_CONTAINER_ENV_PATH,
|
||||
DEV_CONTAINER_USER_ENV_PATH,
|
||||
]
|
||||
|
||||
if DEV_CONTAINER_USER_ENV_PATH.exists():
|
||||
env_files.append(DEV_CONTAINER_USER_ENV_PATH)
|
||||
|
||||
env = _load_env(env_files)
|
||||
_handle_rootless(env_inout=env)
|
||||
|
||||
|
||||
3
.gitattributes
vendored
3
.gitattributes
vendored
@ -1,7 +1,8 @@
|
||||
*.a filter=lfs diff=lfs merge=lfs -text
|
||||
*.dll filter=lfs diff=lfs merge=lfs -text
|
||||
*.lib filter=lfs diff=lfs merge=lfs -text
|
||||
*.so filter=lfs diff=lfs merge=lfs -text
|
||||
*.dll filter=lfs diff=lfs merge=lfs -text
|
||||
*.txz filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
triton_backend/tools/gpt/input_data.json filter=lfs diff=lfs merge=lfs -text
|
||||
*cubin.cpp filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
3
.github/CODEOWNERS
vendored
3
.github/CODEOWNERS
vendored
@ -19,6 +19,9 @@
|
||||
/tensorrt_llm/commands/bench.py @NVIDIA/trtllm-bench-reviewers
|
||||
docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
|
||||
|
||||
## TensorRT-LLM LLM API
|
||||
/tensorrt_llm/llmapi @NVIDIA/trt-llm-llmapi-devs
|
||||
/tensorrt_llm/executor @NVIDIA/trt-llm-llmapi-devs
|
||||
|
||||
# The rule below requires that any PR modifying public APIs must be approved by at least one member
|
||||
# of the NVIDIA/trt-llm-committed-api-review-committee or NVIDIA/trt-llm-noncommitted-api-review-committee team.
|
||||
|
||||
2
.github/workflows/blossom-ci.yml
vendored
2
.github/workflows/blossom-ci.yml
vendored
@ -40,7 +40,7 @@ jobs:
|
||||
startsWith(github.event.comment.body, '/bot skip --comment') ||
|
||||
startsWith(github.event.comment.body, '/bot reuse-pipeline') ||
|
||||
startsWith(github.event.comment.body, '/bot kill')) && contains(
|
||||
fromJson('["byshiue","chuangz0","funatiq","hypdeb","jdemouth-nvidia","joyang-nv","lowsfer","Tabrizian","yweng0828","Shixiaowei02","MartinMarciniszyn","schetlur-nv","dcampora","pcastonguay","Naveassaf","lfr-0531","nekorobov","PerkzZheng","kaiyux","nv-guomingz","LinPoly","thorjohnsen","jiahanc","latency1024","tburt-nv","zeroepoch","chzblych","niukuo","ZhanruiSunCh","EmmaQiaoCh","yiqingy0","achartier","suyoggupta","amukkara","mk-nvidia","QiJune","lucaslie","davidmlw","hlu1","nvzhou","syuoni","NVGaryJi","symphonylyh","hello-11","zongfeijing","Jackch-NV","jinyangyuan-nvidia","LarryXFly","crazydemo","jaedeok-nvidia","wm2012011492","rosenrodt","zhuoyao1012","xinhe-nv","Yuening-wa","Shunkangz","zhengd-nv","yibinl-nvidia","StanleySun639","KingsleyLiu-NV","kxdc","yingcanw","BestJuly","ChristinaZ","bobboli","xueweilnvidia","kunlunl","cherichy","lucifer1004","Autumn1998","litaotju","peaceh-nv","liji-nv","SimengLiu-nv","yuxianq","yechank-nvidia","vallis-neria","DylanChen-NV","Tracin","zhhuang-nv","ISEEKYAN","xupinjie","tongyuantongyu","laikhtewari","zhuolingwang","dominicshanshan","jershi425","shifangx","StudyingShao","Superjomn","dongjiyingdjy","guangyunh-nv","wili-65535","tiffany940107","DanBlanaru","mikeiovine","djns99","ruodil","xiaoweiw-nv","xuwchen","bashimao","yizhang-nv","hyukn","nvpohanh","yuki-666","juney-nvidia","barry-delaney","Kefeng-Duan","MinaHuai","yilin-void","jhaotingc","jmydurant","katec846","CarstyYou","Njuapp","Jie-Fang","nvbrantz","inocsin","ruoqianguo","chenfeiz0326","ming-wei","eopXD","longlee0622","dongfengy","georgeliu95","evezhier","rakib-hasan","shangz-ai","JyChang012","wangsiping1997","yuanjings-nvda","tomeras91","roikoren755","amirkl94","shaharmor98","danielafrimi","amitz-nv","hijkzzz","rzilberstein-nvidia","dc3671","hchings","yuhengxnv","dongxuy04","qiaoxj07","omera-nv","DomBrown","brb-nv","FrankD412","yuhsuan-t","Fridah-nv","a-mccarthy","HuiGao-NV","alexmsettle","meenchen","sugunav14","cjluo-nv","kyleliang-nv","chang-l","WeiHaocheng","qixiang-99","BatshevaBlack","ebarilanM","xmchen1987","lingjiew","heyuhhh","netanel-haber","jiefangz-nv","wyw1267","yunruis","sklevtsov-nvidia","jgangani","pamelap-nvidia","ixlmar","GalSha","Dido0o0","rabiel","nvzhihanj","milesial","fzmu727","zackyoray","RoeyAzran1992","viraatc","v-shobhit","yuanjingx87","uchihatmtkinu","nvrohanv","vegaluisjose","qsang-nv","ChunhuanLin","timlee0212","venkywonka","zbpatel","tijyojwad","shyeh25","zihaok","nv-yilinf","ttyio","farazkh80","yuantailing","JennyLiu-nv","moraxu","IzzyPutterman","nvchenghaoz","nvxuanyuc","poweiw","stnie","zhanga5","nzmora-nvidia","greg-kwasniewski1","linda-stadter","Tom-Zheng","vanshilshah97","ixlmar","MatthiasKohl","Wanli-Jiang", "arekay", "davidclark-nv", "2ez4bz", "tcherckez-nvidia", "MrGeva", "galagam", "limin2021", "dhansen-nvidia","talorabr","kanghui0204","wu6u3tw","hvagadia","xavier-nvidia","raayandhar"]'),
|
||||
fromJson('["byshiue","chuangz0","funatiq","hypdeb","jdemouth-nvidia","joyang-nv","lowsfer","Tabrizian","yweng0828","Shixiaowei02","MartinMarciniszyn","schetlur-nv","dcampora","pcastonguay","Naveassaf","lfr-0531","nekorobov","PerkzZheng","kaiyux","nv-guomingz","LinPoly","thorjohnsen","jiahanc","latency1024","tburt-nv","zeroepoch","chzblych","niukuo","ZhanruiSunCh","EmmaQiaoCh","yiqingy0","achartier","suyoggupta","amukkara","mk-nvidia","QiJune","lucaslie","davidmlw","hlu1","nvzhou","syuoni","NVGaryJi","symphonylyh","hello-11","zongfeijing","Jackch-NV","jinyangyuan-nvidia","LarryXFly","crazydemo","jaedeok-nvidia","wm2012011492","rosenrodt","zhuoyao1012","xinhe-nv","Yuening-wa","Shunkangz","zhengd-nv","yibinl-nvidia","StanleySun639","KingsleyLiu-NV","kxdc","yingcanw","BestJuly","ChristinaZ","bobboli","xueweilnvidia","kunlunl","cherichy","lucifer1004","Autumn1998","litaotju","peaceh-nv","liji-nv","SimengLiu-nv","yuxianq","yechank-nvidia","vallis-neria","DylanChen-NV","Tracin","zhhuang-nv","ISEEKYAN","xupinjie","tongyuantongyu","laikhtewari","zhuolingwang","dominicshanshan","jershi425","shifangx","StudyingShao","Superjomn","dongjiyingdjy","guangyunh-nv","wili-65535","tiffany940107","DanBlanaru","mikeiovine","djns99","ruodil","xiaoweiw-nv","xuwchen","bashimao","yizhang-nv","hyukn","nvpohanh","yuki-666","juney-nvidia","barry-delaney","Kefeng-Duan","MinaHuai","yilin-void","jhaotingc","jmydurant","katec846","CarstyYou","Njuapp","Jie-Fang","nvbrantz","inocsin","ruoqianguo","chenfeiz0326","ming-wei","eopXD","longlee0622","dongfengy","georgeliu95","evezhier","rakib-hasan","shangz-ai","JyChang012","wangsiping1997","yuanjings-nvda","tomeras91","roikoren755","amirkl94","shaharmor98","danielafrimi","amitz-nv","hijkzzz","rzilberstein-nvidia","dc3671","hchings","yuhengxnv","dongxuy04","qiaoxj07","omera-nv","DomBrown","brb-nv","FrankD412","yuhsuan-t","Fridah-nv","a-mccarthy","HuiGao-NV","alexmsettle","meenchen","sugunav14","cjluo-nv","kyleliang-nv","chang-l","WeiHaocheng","qixiang-99","BatshevaBlack","ebarilanM","xmchen1987","lingjiew","heyuhhh","netanel-haber","jiefangz-nv","wyw1267","yunruis","sklevtsov-nvidia","jgangani","pamelap-nvidia","ixlmar","GalSha","Dido0o0","rabiel","nvzhihanj","milesial","fzmu727","zackyoray","RoeyAzran1992","viraatc","v-shobhit","yuanjingx87","uchihatmtkinu","nvrohanv","vegaluisjose","qsang-nv","ChunhuanLin","timlee0212","venkywonka","zbpatel","tijyojwad","shyeh25","zihaok","nv-yilinf","ttyio","farazkh80","yuantailing","JennyLiu-nv","moraxu","IzzyPutterman","nvchenghaoz","nvxuanyuc","poweiw","stnie","zhanga5","nzmora-nvidia","greg-kwasniewski1","linda-stadter","Tom-Zheng","vanshilshah97","ixlmar","MatthiasKohl","Wanli-Jiang", "arekay", "davidclark-nv", "2ez4bz", "tcherckez-nvidia", "MrGeva", "galagam", "limin2021", "dhansen-nvidia","talorabr","kanghui0204","wu6u3tw","hvagadia","xavier-nvidia","raayandhar","dbari","nvjullin","elvischenv","zhenhuaw-me","weireweire","yifeizhang-c","jiaganc","ziyixiong-nv","FelixXidddd","JunyiXu-nv","bo-nv","zerollzeng","RayenTian","ameynaik-hub"]'),
|
||||
github.actor)
|
||||
steps:
|
||||
- name: Check if comment is issued by authorized person
|
||||
|
||||
1
.github/workflows/label_community_pr.yml
vendored
1
.github/workflows/label_community_pr.yml
vendored
@ -14,6 +14,7 @@ on:
|
||||
jobs:
|
||||
label_pr:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'NVIDIA/TensorRT-LLM'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@ -40,6 +40,9 @@ tensorrt_llm/libs
|
||||
tensorrt_llm/bindings.*.so
|
||||
tensorrt_llm/bindings.pyi
|
||||
tensorrt_llm/bindings/**/*.pyi
|
||||
tensorrt_llm/deep_ep/
|
||||
tensorrt_llm/deep_ep_cpp_tllm.*.so
|
||||
tensorrt_llm/deep_ep_cpp_tllm.pyi
|
||||
*docs/cpp_docs*
|
||||
*docs/source/_cpp_gen*
|
||||
docs/source/**/*.rst
|
||||
@ -55,6 +58,7 @@ llm-test-workspace/
|
||||
*.safetensors
|
||||
*/tllm_debug/**
|
||||
*.patch
|
||||
!cpp/tensorrt_llm/deep_ep/*.patch
|
||||
|
||||
# Generated files
|
||||
cpp/include/tensorrt_llm/executor/version.h
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -20,3 +20,6 @@
|
||||
[submodule "3rdparty/xgrammar"]
|
||||
path = 3rdparty/xgrammar
|
||||
url = https://github.com/mlc-ai/xgrammar.git
|
||||
[submodule "3rdparty/nanobind"]
|
||||
path = 3rdparty/nanobind
|
||||
url = https://github.com/wjakob/nanobind
|
||||
|
||||
@ -27,6 +27,7 @@ repos:
|
||||
args: [--allow-multiple-documents]
|
||||
exclude: ".*/gitlab/.*.yml"
|
||||
- id: trailing-whitespace
|
||||
exclude: '\.patch$'
|
||||
- id: check-toml
|
||||
- id: mixed-line-ending
|
||||
args: [--fix=lf]
|
||||
|
||||
1
3rdparty/nanobind
vendored
Submodule
1
3rdparty/nanobind
vendored
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit a0ed2587f1089ef7657e2ed49ad6756b01c74e9f
|
||||
33
README.md
33
README.md
@ -9,7 +9,7 @@ TensorRT-LLM
|
||||
[](https://www.python.org/downloads/release/python-31012/)
|
||||
[](https://developer.nvidia.com/cuda-downloads)
|
||||
[](https://developer.nvidia.com/tensorrt)
|
||||
[](./tensorrt_llm/version.py)
|
||||
[](./tensorrt_llm/version.py)
|
||||
[](./LICENSE)
|
||||
|
||||
[Architecture](./docs/source/torch/arch_overview.md) | [Performance](./docs/source/performance/perf-overview.md) | [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) | [Documentation](./docs/source/) | [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
|
||||
@ -61,22 +61,23 @@ TensorRT-LLM
|
||||
* [02/12] 🌟 How Scaling Laws Drive Smarter, More Powerful AI
|
||||
[➡️ link](https://blogs.nvidia.com/blog/ai-scaling-laws/?ncid=so-link-889273&linkId=100000338837832)
|
||||
|
||||
* [01/25] Nvidia moves AI focus to inference cost, efficiency [➡️ link](https://www.fierceelectronics.com/ai/nvidia-moves-ai-focus-inference-cost-efficiency?linkId=100000332985606)
|
||||
|
||||
* [01/24] 🏎️ Optimize AI Inference Performance with NVIDIA Full-Stack Solutions [➡️ link](https://developer.nvidia.com/blog/optimize-ai-inference-performance-with-nvidia-full-stack-solutions/?ncid=so-twit-400810&linkId=100000332621049)
|
||||
|
||||
* [01/23] 🚀 Fast, Low-Cost Inference Offers Key to Profitable AI [➡️ link](https://blogs.nvidia.com/blog/ai-inference-platform/?ncid=so-twit-693236-vt04&linkId=100000332307804)
|
||||
|
||||
* [01/16] Introducing New KV Cache Reuse Optimizations in TensorRT-LLM [➡️ link](https://developer.nvidia.com/blog/introducing-new-kv-cache-reuse-optimizations-in-nvidia-tensorrt-llm/?ncid=so-twit-363876&linkId=100000330323229)
|
||||
|
||||
* [01/14] 📣 Bing's Transition to LLM/SLM Models: Optimizing Search with TensorRT-LLM [➡️ link](https://blogs.bing.com/search-quality-insights/December-2024/Bing-s-Transition-to-LLM-SLM-Models-Optimizing-Search-with-TensorRT-LLM)
|
||||
|
||||
* [01/04] ⚡Boost Llama 3.3 70B Inference Throughput 3x with TensorRT-LLM Speculative Decoding
|
||||
[➡️ link](https://developer.nvidia.com/blog/boost-llama-3-3-70b-inference-throughput-3x-with-nvidia-tensorrt-llm-speculative-decoding/)
|
||||
|
||||
<details close>
|
||||
<summary>Previous News</summary>
|
||||
|
||||
* [2025/01/25] Nvidia moves AI focus to inference cost, efficiency [➡️ link](https://www.fierceelectronics.com/ai/nvidia-moves-ai-focus-inference-cost-efficiency?linkId=100000332985606)
|
||||
|
||||
* [2025/01/24] 🏎️ Optimize AI Inference Performance with NVIDIA Full-Stack Solutions [➡️ link](https://developer.nvidia.com/blog/optimize-ai-inference-performance-with-nvidia-full-stack-solutions/?ncid=so-twit-400810&linkId=100000332621049)
|
||||
|
||||
* [2025/01/23] 🚀 Fast, Low-Cost Inference Offers Key to Profitable AI [➡️ link](https://blogs.nvidia.com/blog/ai-inference-platform/?ncid=so-twit-693236-vt04&linkId=100000332307804)
|
||||
|
||||
* [2025/01/16] Introducing New KV Cache Reuse Optimizations in TensorRT-LLM [➡️ link](https://developer.nvidia.com/blog/introducing-new-kv-cache-reuse-optimizations-in-nvidia-tensorrt-llm/?ncid=so-twit-363876&linkId=100000330323229)
|
||||
|
||||
* [2025/01/14] 📣 Bing's Transition to LLM/SLM Models: Optimizing Search with TensorRT-LLM [➡️ link](https://blogs.bing.com/search-quality-insights/December-2024/Bing-s-Transition-to-LLM-SLM-Models-Optimizing-Search-with-TensorRT-LLM)
|
||||
|
||||
* [2025/01/04] ⚡Boost Llama 3.3 70B Inference Throughput 3x with TensorRT-LLM Speculative Decoding
|
||||
[➡️ link](https://developer.nvidia.com/blog/boost-llama-3-3-70b-inference-throughput-3x-with-nvidia-tensorrt-llm-speculative-decoding/)
|
||||
|
||||
* [2024/12/10] ⚡ Llama 3.3 70B from AI at Meta is accelerated by TensorRT-LLM. 🌟 State-of-the-art model on par with Llama 3.1 405B for reasoning, math, instruction following and tool use. Explore the preview
|
||||
[➡️ link](https://build.nvidia.com/meta/llama-3_3-70b-instruct)
|
||||
|
||||
@ -204,11 +205,9 @@ Serverless TensorRT-LLM (LLaMA 3 8B) | Modal Docs [➡️ link](https://modal.co
|
||||
|
||||
TensorRT-LLM is an open-sourced library for optimizing Large Language Model (LLM) inference. It provides state-of-the-art optimizations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, [FP4](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/), INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ...), speculative decoding, and much more, to perform inference efficiently on NVIDIA GPUs.
|
||||
|
||||
Recently [re-architected with a **PyTorch backend**](https://nvidia.github.io/TensorRT-LLM/torch.html), TensorRT-LLM now combines peak performance with a more flexible and developer-friendly workflow. The original [TensorRT](https://developer.nvidia.com/tensorrt)-based backend remains supported and continues to provide an ahead-of-time compilation path for building highly optimized "[Engines](https://docs.nvidia.com/deeplearning/tensorrt/quick-start-guide/index.html#ecosystem)" for deployment. The PyTorch backend complements this by enabling faster development iteration and rapid experimentation.
|
||||
[Architected on PyTorch](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/torch/arch_overview.md), TensorRT-LLM provides a high-level Python [LLM API](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html#llm-api) that supports a wide range of inference setups - from single-GPU to multi-GPU or multi-node deployments. It includes built-in support for various parallelism strategies and advanced features. The LLM API integrates seamlessly with the broader inference ecosystem, including NVIDIA [Dynamo](https://github.com/ai-dynamo/dynamo) and the [Triton Inference Server](https://github.com/triton-inference-server/server).
|
||||
|
||||
TensorRT-LLM provides a flexible [**LLM API**](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html#llm-api) to simplify model setup and inference across both PyTorch and TensorRT backends. It supports a wide range of inference use cases from a single GPU to multiple nodes with multiple GPUs using [Tensor Parallelism](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism) and/or [Pipeline Parallelism](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#pipeline-parallelism). It also includes a [backend](https://github.com/triton-inference-server/tensorrtllm_backend) for integration with the [NVIDIA Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server).
|
||||
|
||||
Several popular models are pre-defined and can be easily customized or extended using [native PyTorch code](./tensorrt_llm/_torch/models/modeling_deepseekv3.py) (for the PyTorch backend) or a [PyTorch-style Python API](./tensorrt_llm/models/llama/model.py) (for the TensorRT backend).
|
||||
TensorRT-LLM is designed to be modular and easy to modify. Its PyTorch-native architecture allows developers to experiment with the runtime or extend functionality. Several popular models are also pre-defined and can be customized using [native PyTorch code](./tensorrt_llm/_torch/models/modeling_deepseekv3.py), making it easy to adapt the system to specific needs.
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
||||
@ -7,3 +7,7 @@ h11>=0.16.0
|
||||
tornado>=6.5.0
|
||||
# WAR against https://github.com/advisories/GHSA-5rjg-fvgr-3xxf
|
||||
setuptools>=78.1.1
|
||||
# WAR against https://github.com/advisories/GHSA-8qvm-5x2c-j2w7
|
||||
protobuf>=4.25.8
|
||||
# WAR against https://github.com/advisories/GHSA-33p9-3p43-82vq
|
||||
jupyter-core>=5.8.1
|
||||
|
||||
@ -28,8 +28,6 @@ project(tensorrt_llm LANGUAGES CXX)
|
||||
|
||||
# Build options
|
||||
option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON)
|
||||
option(BUILD_PYBIND "Build Python bindings for C++ runtime and batch manager"
|
||||
ON)
|
||||
option(BUILD_TESTS "Build Google tests" ON)
|
||||
option(BUILD_BENCHMARKS "Build benchmarks" ON)
|
||||
option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
|
||||
@ -45,6 +43,7 @@ option(ENABLE_MULTI_DEVICE
|
||||
option(ENABLE_UCX "Enable building with UCX (Uniform Communication X) support"
|
||||
ON)
|
||||
option(NVRTC_DYNAMIC_LINKING "Link against the dynamic NVRTC libraries" OFF)
|
||||
option(ENABLE_NVSHMEM "Enable building with NVSHMEM support" OFF)
|
||||
option(USING_OSS_CUTLASS_LOW_LATENCY_GEMM
|
||||
"Using open sourced Cutlass low latency gemm kernel" ON)
|
||||
option(USING_OSS_CUTLASS_FP4_GEMM "Using open sourced Cutlass fp4 gemm kernel"
|
||||
@ -54,6 +53,8 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
|
||||
option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
|
||||
"Using open sourced Cutlass AR gemm kernel" ON)
|
||||
|
||||
message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")
|
||||
|
||||
if(NVTX_DISABLE)
|
||||
add_compile_definitions("NVTX_DISABLE")
|
||||
message(STATUS "NVTX is disabled")
|
||||
@ -65,6 +66,11 @@ endif()
|
||||
add_compile_definitions("TLLM_GEN_EXPORT_INTERFACE")
|
||||
add_compile_definitions("TLLM_ENABLE_CUDA")
|
||||
|
||||
set(BINDING_TYPE
|
||||
"pybind"
|
||||
CACHE STRING
|
||||
"Binding type of Python bindings for C++ runtime and batch manager")
|
||||
|
||||
set(INTERNAL_CUTLASS_KERNELS_PATH
|
||||
""
|
||||
CACHE
|
||||
@ -171,6 +177,7 @@ message(STATUS "CUDA library status:")
|
||||
message(STATUS " version: ${CUDAToolkit_VERSION}")
|
||||
message(STATUS " libraries: ${CUDAToolkit_LIBRARY_DIR}")
|
||||
message(STATUS " include path: ${CUDAToolkit_INCLUDE_DIRS}")
|
||||
message(STATUS "CUDA_NVML_LIB: ${CUDA_NVML_LIB}")
|
||||
|
||||
# Prevent CMake from creating a response file for CUDA compiler, so clangd can
|
||||
# pick up on the includes
|
||||
@ -191,7 +198,14 @@ set(TRT_LIB TensorRT::NvInfer)
|
||||
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
|
||||
|
||||
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
|
||||
add_subdirectory(${3RDPARTY_DIR}/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/pybind11)
|
||||
if(BINDING_TYPE STREQUAL "pybind")
|
||||
add_subdirectory(${3RDPARTY_DIR}/pybind11
|
||||
${CMAKE_CURRENT_BINARY_DIR}/pybind11)
|
||||
endif()
|
||||
if(BINDING_TYPE STREQUAL "nanobind")
|
||||
add_subdirectory(${3RDPARTY_DIR}/nanobind
|
||||
${CMAKE_CURRENT_BINARY_DIR}/nanobind)
|
||||
endif()
|
||||
|
||||
# include as system to suppress warnings
|
||||
include_directories(
|
||||
@ -202,8 +216,13 @@ include_directories(
|
||||
${3RDPARTY_DIR}/cutlass/include
|
||||
${3RDPARTY_DIR}/cutlass/tools/util/include
|
||||
${3RDPARTY_DIR}/NVTX/include
|
||||
${3RDPARTY_DIR}/json/include
|
||||
${3RDPARTY_DIR}/pybind11/include)
|
||||
${3RDPARTY_DIR}/json/include)
|
||||
if(BINDING_TYPE STREQUAL "pybind")
|
||||
include_directories(${3RDPARTY_DIR}/pybind11/include)
|
||||
endif()
|
||||
if(BINDING_TYPE STREQUAL "nanobind")
|
||||
include_directories(${3RDPARTY_DIR}/nanobind/include)
|
||||
endif()
|
||||
|
||||
if(${CUDAToolkit_VERSION} VERSION_GREATER_EQUAL "11")
|
||||
add_definitions("-DENABLE_BF16")
|
||||
@ -262,9 +281,21 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss ")
|
||||
# note: cmake expr generation $<BOOL:${ENABLE_MULTI_DEVICE}> is a build time
|
||||
# evaluation so hard to debug at cmake time
|
||||
if(ENABLE_MULTI_DEVICE)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MULTI_DEVICE=1")
|
||||
# Add target definitions for both C++ and CUDA
|
||||
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_MULTI_DEVICE=1>
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_MULTI_DEVICE=1>)
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MULTI_DEVICE=0")
|
||||
# Add target definitions for both C++ and CUDA
|
||||
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_MULTI_DEVICE=0>
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_MULTI_DEVICE=0>)
|
||||
endif()
|
||||
|
||||
if(ENABLE_NVSHMEM)
|
||||
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_NVSHMEM=1>
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=1>)
|
||||
else()
|
||||
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_NVSHMEM=0>
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
|
||||
endif()
|
||||
|
||||
# Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
|
||||
|
||||
@ -44,13 +44,13 @@ public:
|
||||
KVCacheEventManager(KVCacheEventManager&& other) = delete;
|
||||
KVCacheEventManager& operator=(KVCacheEventManager&& other) = delete;
|
||||
|
||||
void enqueueCreatedEvent(std::vector<SizeType32> const& numBlocksPerCacheLevel);
|
||||
void enqueueCreatedEvent(std::vector<SizeType32> const& numBlocksPerCacheLevel, SizeType32 windowSize);
|
||||
|
||||
void enqueueStoredEvent(std::vector<BlockPtr> const& blocks);
|
||||
void enqueueStoredEvent(std::vector<BlockPtr> const& blocks, SizeType32 windowSize);
|
||||
|
||||
void enqueueRemovedEvent(BlockPtr const& block);
|
||||
void enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize);
|
||||
|
||||
void enqueueUpdatedEvent(executor::KVCacheUpdatedData const& data);
|
||||
void enqueueUpdatedEvent(executor::KVCacheUpdatedData const& data, SizeType32 windowSize);
|
||||
|
||||
// Get events in mEvents. If there are no events, wait for a maximum of `timeout` milliseconds.
|
||||
std::deque<executor::KVCacheEvent> getEvents(std::optional<std::chrono::milliseconds> timeout);
|
||||
|
||||
@ -553,6 +553,8 @@ public:
|
||||
|
||||
void storeBlocksForReuse(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
|
||||
|
||||
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
|
||||
|
||||
//! \brief Release blocks of the sequence.
|
||||
void releaseBlocks(GenerationRequest& sequence);
|
||||
|
||||
@ -1092,6 +1094,9 @@ public:
|
||||
//! \brief Store context blocks
|
||||
void storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest);
|
||||
|
||||
//! \brief Store newest block for reuse
|
||||
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
|
||||
|
||||
[[nodiscard]] static bool isUseOneMoreBlock(
|
||||
SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
|
||||
{
|
||||
@ -1262,6 +1267,10 @@ public:
|
||||
//! \details These blocks become reusable from next step.
|
||||
virtual void storeContextBlocks(LlmRequest const& llmRequest) = 0;
|
||||
|
||||
//! \brief Store newest block for reuse.
|
||||
//! \details This block become reusable from next step.
|
||||
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;
|
||||
|
||||
//! \brief Get the block ids of a request [per beam] **for a given window size block manager**
|
||||
[[nodiscard]] virtual std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
|
||||
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
|
||||
@ -1568,6 +1577,9 @@ public:
|
||||
//! \details These blocks become reusable from next step.
|
||||
void storeContextBlocks(LlmRequest const& llmRequest) override;
|
||||
|
||||
//! \brief Store newest blocks for reuse
|
||||
void storeNewBlock(LlmRequest const& llmRequest) override;
|
||||
|
||||
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
|
||||
|
||||
[[nodiscard]] SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const override;
|
||||
|
||||
@ -303,7 +303,12 @@ inline int getSMVersion()
|
||||
int sm_minor = 0;
|
||||
check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
|
||||
check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
|
||||
return sm_major * 10 + sm_minor;
|
||||
int sm = sm_major * 10 + sm_minor;
|
||||
if (sm == 121)
|
||||
{
|
||||
return 120;
|
||||
}
|
||||
return sm;
|
||||
}
|
||||
|
||||
inline int getDevice()
|
||||
|
||||
@ -1709,12 +1709,14 @@ using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVC
|
||||
struct KVCacheEvent
|
||||
{
|
||||
|
||||
KVCacheEvent(IdType eventId, KVCacheEventData data);
|
||||
KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize);
|
||||
|
||||
/// @brief The unique id of this event
|
||||
IdType eventId;
|
||||
/// @brief The data corresponding to this event
|
||||
KVCacheEventData data;
|
||||
/// @brief The sliding window size
|
||||
SizeType32 windowSize;
|
||||
};
|
||||
|
||||
/// @brief Exposes a limited set of KV cache manager functionalities
|
||||
|
||||
@ -182,6 +182,14 @@ public:
|
||||
//! @brief Cache indirection output for beam search.
|
||||
[[nodiscard]] TensorPtr getCacheIndirectionOutput() const;
|
||||
|
||||
//! @brief Get the generation steps for all requests in the batch.
|
||||
//! @returns The generation steps for all requests in the batch.
|
||||
[[nodiscard]] std::optional<std::vector<SizeType32>> const& getGenerationSteps() const;
|
||||
|
||||
//! @brief Set the generation steps for all requests in the batch.
|
||||
//! @param generationSteps The generation steps for all requests in the batch.
|
||||
void setGenerationSteps(std::vector<SizeType32> const& generationSteps);
|
||||
|
||||
//! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
|
||||
[[nodiscard]] DecodingInput& getJointDecodingInput() const;
|
||||
|
||||
|
||||
@ -142,24 +142,6 @@ public:
|
||||
|
||||
struct EagleInputs
|
||||
{
|
||||
EagleInputs(TensorConstPtr nextDraftTokens, TensorConstPtr nextDraftLens, TensorConstPtr nextDraftPaths,
|
||||
TensorConstPtr lastDraftTokens, TensorConstPtr lastDraftLens, TensorConstPtr lastDraftPaths,
|
||||
TensorConstPtr acceptedTokens, TensorConstPtr acceptedLens, TensorConstPtr acceptedPathIds,
|
||||
TensorConstPtr chunkedContextNextTokens, TensorConstPtr seqSlots)
|
||||
: nextDraftTokens(std::move(nextDraftTokens))
|
||||
, nextDraftLens(std::move(nextDraftLens))
|
||||
, nextDraftPaths(std::move(nextDraftPaths))
|
||||
, lastDraftTokens(std::move(lastDraftTokens))
|
||||
, lastDraftLens(std::move(lastDraftLens))
|
||||
, lastDraftPaths(std::move(lastDraftPaths))
|
||||
, acceptedTokens(std::move(acceptedTokens))
|
||||
, acceptedLens(std::move(acceptedLens))
|
||||
, acceptedPathIds(std::move(acceptedPathIds))
|
||||
, chunkedContextNextTokens(std::move(chunkedContextNextTokens))
|
||||
, seqSlots(std::move(seqSlots))
|
||||
{
|
||||
}
|
||||
|
||||
TensorConstPtr nextDraftTokens; // [batchSize, maxDecodingDraftTokens]
|
||||
TensorConstPtr nextDraftLens; // [batchSize]
|
||||
TensorConstPtr nextDraftPaths; // [batchSize, maxDecodingTokens, maxPathLen]
|
||||
|
||||
@ -18,9 +18,9 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/eagleBuffers.h"
|
||||
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
@ -72,25 +72,6 @@ public:
|
||||
|
||||
//! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
|
||||
std::vector<TensorPtr> batchSlots;
|
||||
//! Filled with slots in request order, [batchSize]
|
||||
TensorPtr batchSlotsRequestOrder;
|
||||
|
||||
//! For Beam Search
|
||||
//! The generation step of each request (for Variable-Beam-Width-Search), [batchSize]
|
||||
std::vector<SizeType32> generationSteps;
|
||||
|
||||
//! For speculative decoding
|
||||
//! Logits of draft
|
||||
//! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
|
||||
std::vector<std::vector<TensorPtr>> predictedDraftLogits;
|
||||
|
||||
//! Explicit draft tokens data
|
||||
std::optional<ExplicitDraftTokensBuffers::EngineOutputs> explicitDraftTokensInputs;
|
||||
std::optional<ExplicitDraftTokensBuffers::EngineInputs> explicitDraftTokensLastInputs;
|
||||
|
||||
//! Eagle data
|
||||
std::optional<EagleBuffers::EngineOutputs> eagleInputs;
|
||||
std::optional<EagleBuffers::Inputs> eagleLastInputs;
|
||||
};
|
||||
|
||||
} // namespace decoder_batch
|
||||
|
||||
@ -3049,11 +3049,13 @@ def get_kernel_traits_code(specs_names):
|
||||
return code
|
||||
|
||||
|
||||
# For now, only hopper head_size 128 kernel uses cubins, and other kernels use cu files.
|
||||
# You should set the condition `use_cubin_header` to false if you have modified the source code of the FMHA kernels on Hopper (sm90) with head_size 128.
|
||||
# For now:
|
||||
# 1. Hopper head_size 128 kernel uses cubins for performance regressions.
|
||||
# 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins for accuracy regressions (will be fixed).
|
||||
# You should set the condition `use_cubin_header` to false if you have modified the source codes of those kernels that use cubins.
|
||||
# This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins.
|
||||
def use_cubin_header(kspec):
|
||||
return kspec.sm == 90 and kspec.head_size == 128
|
||||
def use_cubin_header(sm, head_size, dtype):
|
||||
return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)
|
||||
|
||||
|
||||
def get_cubin_header(kernel_traits, specs_names):
|
||||
@ -3062,7 +3064,8 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
cubins_dict = {}
|
||||
cubin_lens_dict = {}
|
||||
for kspec, fname, lname, kname in specs_names:
|
||||
if generate_cu_trtllm and not use_cubin_header(kspec):
|
||||
if generate_cu_trtllm and not use_cubin_header(
|
||||
kspec.sm, kspec.head_size, kspec.dtype):
|
||||
continue
|
||||
name = fname.replace('.', '_')
|
||||
data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
|
||||
@ -3215,7 +3218,7 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
if generate_cu_trtllm:
|
||||
|
||||
def get_lname_from_kname(kname: str) -> str:
|
||||
if use_cubin_header(kspec):
|
||||
if use_cubin_header(int(sm), int(head_size), prec.lower()):
|
||||
return 'nullptr'
|
||||
lname = kname.replace('_kernel', '')
|
||||
mask_types = [
|
||||
@ -3234,7 +3237,8 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
|
||||
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
|
||||
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
|
||||
'''.format(**locals()) if use_cubin_header(kspec) else '''\
|
||||
'''.format(**locals()) if use_cubin_header(int(sm), int(head_size),
|
||||
prec.lower()) else '''\
|
||||
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
|
||||
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
|
||||
0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
|
||||
|
||||
@ -23,7 +23,7 @@ struct Kv_block_array
|
||||
{
|
||||
using PtrType = int32_t;
|
||||
|
||||
// Current number of sequences
|
||||
// Maximum number of sequences supported by the kv-cache.
|
||||
int32_t mMaxSeqs;
|
||||
// Max number of blocks per sequence
|
||||
int32_t mMaxBlocksPerSeq;
|
||||
|
||||
@ -44,8 +44,9 @@ function(add_benchmark test_name test_src)
|
||||
benchmark::benchmark)
|
||||
|
||||
target_compile_features(${test_name} PRIVATE cxx_std_17)
|
||||
target_compile_definitions(${test_name}
|
||||
PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
|
||||
target_compile_definitions(
|
||||
${test_name} PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}"
|
||||
USING_OSS_CUTLASS_MOE_GEMM)
|
||||
|
||||
add_dependencies(micro_benchmarks ${test_name})
|
||||
endfunction()
|
||||
|
||||
@ -222,7 +222,30 @@ struct UniformRoutingConfig : public RoutingConfig
|
||||
{
|
||||
std::uniform_int_distribution<int> dist(0, num_experts - 1);
|
||||
std::vector<int> input(k * num_tokens);
|
||||
std::generate(input.begin(), input.end(), [&] { return dist(twister); });
|
||||
for (int i = 0; i < num_tokens; i++)
|
||||
{
|
||||
for (int j = 0; j < k; j++)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
int expert_id = dist(twister);
|
||||
bool valid = true;
|
||||
for (int prev_j = 0; prev_j < j; prev_j++)
|
||||
{
|
||||
if (expert_id == input[i * k + prev_j])
|
||||
{
|
||||
valid = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (valid)
|
||||
{
|
||||
input[i * k + j] = expert_id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
check_cuda_error(cudaMemcpyAsync(
|
||||
selected_experts, input.data(), input.size() * sizeof(int), cudaMemcpyHostToDevice, streamPtr->get()));
|
||||
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
|
||||
@ -322,9 +345,8 @@ public:
|
||||
constexpr static int WEIGHT_ELEM_PER_BYTE = (INT4 || ANY_FP4) ? 2 : 1;
|
||||
int const BASE_HIDDEN_SIZE = 64 / sizeof(WeightType) * WEIGHT_ELEM_PER_BYTE;
|
||||
|
||||
constexpr static int64_t FP4_VECTOR_SIZE = NVFP4
|
||||
? tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
|
||||
: tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize;
|
||||
constexpr static int64_t FP4_VECTOR_SIZE = NVFP4 ? TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
|
||||
: TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize;
|
||||
|
||||
std::vector<BufferManager::IBufferPtr> managed_buffers;
|
||||
int* mSelectedExperts{};
|
||||
@ -476,7 +498,7 @@ public:
|
||||
float* mExpertFP8Scale3{};
|
||||
|
||||
float* mExpertFP4ActScale1{};
|
||||
using ElementSF = tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::ElementSF;
|
||||
using ElementSF = TmaWarpSpecializedGroupedGemmInput::ElementSF;
|
||||
ElementSF* mExpertFP4WeightSf1{};
|
||||
float* mExpertFP4GlobalScale1{};
|
||||
float* mExpertFP4ActScale2{};
|
||||
@ -532,7 +554,7 @@ public:
|
||||
mInterSize = inter_size / parallelism_config.tp_size;
|
||||
mNumExperts = num_experts;
|
||||
mK = k;
|
||||
mIsGated = tensorrt_llm::isGatedActivation(mActType);
|
||||
mIsGated = isGatedActivation(mActType);
|
||||
mGatedMultiplier = mIsGated ? 2 : 1;
|
||||
auto const gated_inter = mInterSize * mGatedMultiplier;
|
||||
|
||||
@ -811,7 +833,7 @@ void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state
|
||||
int const num_tokens = state.range(7);
|
||||
mUseBias = state.range(8);
|
||||
mUseFinalScale = state.range(9);
|
||||
mActType = static_cast<tensorrt_llm::ActivationType>(state.range(10));
|
||||
mActType = static_cast<ActivationType>(state.range(10));
|
||||
int tactic_idx1 = state.range(11);
|
||||
int tactic_idx2 = state.range(12);
|
||||
int const routing_config = state.range(13);
|
||||
|
||||
@ -472,16 +472,16 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
|
||||
if (!has_tactic_ids2)
|
||||
t2 = t1;
|
||||
|
||||
benchmark->Args({num_experts, //
|
||||
get_range("k"), //
|
||||
get_range("hidden_size"), //
|
||||
get_range("inter_size"), //
|
||||
tp_size, ep_size, world_rank, //
|
||||
get_range("num_tokens"), //
|
||||
bias, do_final_scale, //
|
||||
get_range("act_fn", 0, (int) tensorrt_llm::ActivationType::Identity), //
|
||||
t1, //
|
||||
t2, //
|
||||
benchmark->Args({num_experts, //
|
||||
get_range("k"), //
|
||||
get_range("hidden_size"), //
|
||||
get_range("inter_size"), //
|
||||
tp_size, ep_size, world_rank, //
|
||||
get_range("num_tokens"), //
|
||||
bias, do_final_scale, //
|
||||
get_range("act_fn", 0, (int) ActivationType::Identity), //
|
||||
t1, //
|
||||
t2, //
|
||||
*routing_config});
|
||||
}
|
||||
}
|
||||
@ -497,10 +497,10 @@ void argGenHardcoded(benchmark::internal::Benchmark* benchmark)
|
||||
auto inter_size_mul = {4.f}; // {7.f/2.f, 4.f};
|
||||
auto num_tokens = {2048}; // {1, 20, 200, 2048};
|
||||
auto use_bias = {0}; // {0, 1};
|
||||
auto activation_type = {tensorrt_llm::ActivationType::Gelu};
|
||||
// {tensorrt_llm::ActivationType::Relu, tensorrt_llm::ActivationType::Gelu,
|
||||
// tensorrt_llm::ActivationType::Silu, tensorrt_llm::ActivationType::Geglu,
|
||||
// tensorrt_llm::ActivationType::Swiglu};
|
||||
auto activation_type = {ActivationType::Gelu};
|
||||
// {ActivationType::Relu, ActivationType::Gelu,
|
||||
// ActivationType::Silu, ActivationType::Geglu,
|
||||
// ActivationType::Swiglu};
|
||||
auto cutlass_tactic = {-1}; // {0,..., listAllTactics<BenchClass>().size()};
|
||||
auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG}; // {0, 1, 2};
|
||||
|
||||
@ -518,7 +518,7 @@ void argGenHardcoded(benchmark::internal::Benchmark* benchmark)
|
||||
for (auto tactic2 : cutlass_tactic)
|
||||
for (auto routing : routing_config)
|
||||
benchmark->Args({num_expert, k, size, inter_size, 1, 1, 0, tokens, bias,
|
||||
(int) act, tactic1, tactic2, routing});
|
||||
1, (int) act, tactic1, tactic2, routing});
|
||||
}
|
||||
}
|
||||
|
||||
@ -540,8 +540,9 @@ void argGen(benchmark::internal::Benchmark* benchmark)
|
||||
|
||||
// Generic setup
|
||||
benchmark->UseManualTime();
|
||||
benchmark->ArgNames({"Num Experts", "K", "Hidden Size", "Inter Size", "TP Size", "EP Size", "World Rank",
|
||||
"Num Tokens", "Use Bias", "Activation Function", "Tactic ID 1", "Tactic ID 2", "Routing ID"});
|
||||
benchmark->ArgNames(
|
||||
{"Num Experts", "K", "Hidden Size", "Inter Size", "TP Size", "EP Size", "World Rank", "Num Tokens", "Use Bias",
|
||||
"Use Final Scale", "Activation Function", "Tactic ID 1", "Tactic ID 2", "Routing ID"});
|
||||
|
||||
if (workloadFile)
|
||||
argGenLoadFile<BenchClass>(benchmark);
|
||||
|
||||
@ -72,6 +72,12 @@ if(ENABLE_MULTI_DEVICE)
|
||||
include_directories(${MPI_C_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
if(ENABLE_NVSHMEM)
|
||||
# Add hints for aarch64
|
||||
find_package(NVSHMEM REQUIRED HINTS /usr/lib/sbsa-linux-gnu/cmake/nvshmem/)
|
||||
include_directories(/usr/include/nvshmem/)
|
||||
endif()
|
||||
|
||||
if(NOT WIN32)
|
||||
set(DECODER_SHARED_TARGET_0 decoder_attention_0)
|
||||
set(DECODER_SHARED_TARGET_1 decoder_attention_1)
|
||||
@ -231,7 +237,10 @@ if(ENABLE_MULTI_DEVICE)
|
||||
set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} ${MPI_C_LIBRARIES} ${NCCL_LIB})
|
||||
endif()
|
||||
|
||||
message("TRTLLM_LINK_LIBS: ${TRTLLM_LINK_LIBS}")
|
||||
if(ENABLE_NVSHMEM)
|
||||
set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} nvshmem::nvshmem_host
|
||||
nvshmem::nvshmem_device)
|
||||
endif()
|
||||
|
||||
if(NOT WIN32) # Unix-like compilers
|
||||
set(UNDEFINED_FLAG "-Wl,--no-undefined")
|
||||
@ -293,8 +302,16 @@ if(BUILD_PYT)
|
||||
add_subdirectory(thop)
|
||||
endif()
|
||||
|
||||
if(BUILD_PYBIND)
|
||||
if(BINDING_TYPE STREQUAL "pybind")
|
||||
add_subdirectory(pybind)
|
||||
endif()
|
||||
|
||||
if(BINDING_TYPE STREQUAL "nanobind")
|
||||
add_subdirectory(nanobind)
|
||||
endif()
|
||||
|
||||
if(BUILD_DEEP_EP)
|
||||
add_subdirectory(deep_ep)
|
||||
endif()
|
||||
|
||||
add_subdirectory(plugins)
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/nvtxUtils.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheConcatenate.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
@ -43,8 +43,10 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
|
||||
size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size();
|
||||
constexpr SizeType32 beam{0};
|
||||
auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
|
||||
if (common::getEnvDisableSelectiveCacheTransfer())
|
||||
auto poolNum = cacheManager->getBlockManager().getNumPools();
|
||||
if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
|
||||
{
|
||||
// disable selective cache transfer for poolNum > 1
|
||||
return blockRange;
|
||||
}
|
||||
if (requestBlockNum < blockRange.size() && requestBlockNum > 0)
|
||||
@ -59,7 +61,9 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
|
||||
|
||||
BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest)
|
||||
{
|
||||
if (common::getEnvDisableSelectiveCacheTransfer())
|
||||
|
||||
auto poolNum = cacheManager->getBlockManager().getNumPools();
|
||||
if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer())
|
||||
{
|
||||
constexpr SizeType32 beam{0};
|
||||
return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam);
|
||||
@ -72,7 +76,7 @@ bool CacheFormatter::needSendCache(
|
||||
{
|
||||
// int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
|
||||
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
|
||||
if (targetInfo.mDuplicateHeadFactor <= 1)
|
||||
if (targetInfo.mDupHeadFactor <= 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@ -86,7 +90,30 @@ bool CacheFormatter::needSendCache(
|
||||
selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
|
||||
}
|
||||
|
||||
return selfTpRankInDpGroup % targetInfo.mDuplicateHeadFactor == 0;
|
||||
return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
|
||||
}
|
||||
|
||||
void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
|
||||
BaseCacheFormatter::CacheState const& destConfig)
|
||||
{
|
||||
auto numPools = cacheManager->getBlockManager().getNumPools();
|
||||
auto layerNum = cacheManager->getBlockManager().getNumLayers();
|
||||
|
||||
std::vector<SizeType32> poolIdxs(numPools);
|
||||
TLLM_CHECK(layerNum >= numPools);
|
||||
for (int i = 0; i < numPools; i++)
|
||||
{
|
||||
poolIdxs[i] = cacheManager->getBlockManager().getLayerPoolIdx(i);
|
||||
TLLM_LOG_DEBUG("poolIdxs[%d] = %d layerNum:%d", i, poolIdxs[i], layerNum);
|
||||
}
|
||||
|
||||
std::unordered_set<SizeType32> uniquePoolIdxs(poolIdxs.begin(), poolIdxs.end());
|
||||
TLLM_CHECK_WITH_INFO(uniquePoolIdxs.size() == poolIdxs.size(), "poolIdxs must contain unique elements");
|
||||
for (int i = numPools; i < layerNum; i++)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(poolIdxs[i % numPools] == cacheManager->getBlockManager().getLayerPoolIdx(i),
|
||||
"only support Alternate Window");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<executor::kv_cache::Connection const*> CacheFormatter::pickRecvConnections(
|
||||
@ -94,7 +121,7 @@ std::vector<executor::kv_cache::Connection const*> CacheFormatter::pickRecvConne
|
||||
SizeType32 selfIdx, CacheState const& destConfig) const
|
||||
{
|
||||
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
|
||||
if (targetInfo.mPeerDuplicateHeadFactor <= 1)
|
||||
if (targetInfo.mPeerDupHeadFactor <= 1)
|
||||
{
|
||||
return connections;
|
||||
}
|
||||
@ -103,7 +130,7 @@ std::vector<executor::kv_cache::Connection const*> CacheFormatter::pickRecvConne
|
||||
std::vector<executor::kv_cache::Connection const*> ret;
|
||||
for (int i = 0; i < targetInfo.mDomainTPSize; i++)
|
||||
{
|
||||
if (i % targetInfo.mPeerDuplicateHeadFactor == 0)
|
||||
if (i % targetInfo.mPeerDupHeadFactor == 0)
|
||||
{
|
||||
for (int j = 0; j < targetInfo.mDomainPPSize; j++)
|
||||
{
|
||||
@ -170,14 +197,38 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest,
|
||||
else
|
||||
{
|
||||
int blockNum = 0;
|
||||
std::vector<runtime::ITensor::SharedPtr> inputKvCacheBlocks;
|
||||
|
||||
size_t allCacheBlockSize = 0;
|
||||
|
||||
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocks;
|
||||
for (auto poolIdx = 0; poolIdx < numPools; poolIdx++)
|
||||
{
|
||||
blockRange.updatePoolIdx(poolIdx);
|
||||
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(poolIdx);
|
||||
TLLM_CHECK_WITH_INFO(inputKvCacheBlocks.find(window) == inputKvCacheBlocks.end(),
|
||||
"window size already exists, which is not supported");
|
||||
inputKvCacheBlocks.emplace(window, std::vector<runtime::ITensor::SharedPtr>());
|
||||
auto maxBlockThisWindow = window / selfConfig.getModelConfig().mTokensPerBlock;
|
||||
SizeType32 blockNumThisWindow = 0;
|
||||
for (auto it = blockRange.begin(); it != blockRange.end(); ++it)
|
||||
{
|
||||
blockNum++;
|
||||
inputKvCacheBlocks.push_back(it);
|
||||
inputKvCacheBlocks.at(window).push_back(it);
|
||||
allCacheBlockSize += it->getSize();
|
||||
blockNumThisWindow++;
|
||||
if (blockNumThisWindow >= maxBlockThisWindow)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (inputKvCacheBlocks.size() > 1)
|
||||
{
|
||||
if (selfConfig.getParallelConfig().mPipelineParallelism
|
||||
!= destConfig.getParallelConfig().mPipelineParallelism)
|
||||
{
|
||||
checkAlternateWindow(mCacheManager, selfConfig, destConfig);
|
||||
}
|
||||
}
|
||||
TLLM_CHECK(!inputKvCacheBlocks.empty());
|
||||
@ -198,9 +249,12 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest,
|
||||
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
|
||||
for (auto const& connection : connections)
|
||||
{
|
||||
for (auto const& block : inputKvCacheBlocks)
|
||||
for (auto const& [window, blocks] : inputKvCacheBlocks)
|
||||
{
|
||||
TransferHelper::sendBuffer(*connection, *block, llmRequest.mRequestId);
|
||||
for (auto const& block : blocks)
|
||||
{
|
||||
TransferHelper::sendBuffer(*connection, *block, llmRequest.mRequestId);
|
||||
}
|
||||
}
|
||||
}
|
||||
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.",
|
||||
@ -209,16 +263,13 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest,
|
||||
return;
|
||||
}
|
||||
|
||||
auto cacheBlockSize = inputKvCacheBlocks.front()->getSize();
|
||||
|
||||
auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
|
||||
int peerDuplicateHeadFactor = targetInfo.mPeerDuplicateHeadFactor;
|
||||
int peerDuplicateHeadFactor = targetInfo.mPeerDupHeadFactor;
|
||||
auto targetNum = connections.size();
|
||||
TLLM_CHECK((cacheBlockSize * blockNum) % targetNum == 0);
|
||||
auto const targetBufferSize = (cacheBlockSize * blockNum) / targetNum * peerDuplicateHeadFactor;
|
||||
auto const targetBufferSize = allCacheBlockSize / targetNum * peerDuplicateHeadFactor;
|
||||
auto bufferTargetNum = targetNum / peerDuplicateHeadFactor;
|
||||
TLLM_LOG_DEBUG(" formatOutput bufferTargetNum: %d, targetNum: %d, peerDuplicateHeadFactor: %d dupliacete:%d ",
|
||||
bufferTargetNum, targetNum, peerDuplicateHeadFactor, targetInfo.mDuplicateHeadFactor);
|
||||
bufferTargetNum, targetNum, peerDuplicateHeadFactor, targetInfo.mDupHeadFactor);
|
||||
|
||||
auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
|
||||
cacheBufferId, bufferTargetNum, targetBufferSize, bufferManager);
|
||||
@ -240,7 +291,7 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest,
|
||||
auto preAllocSendBuffer = mCacheTransBufferManager->getSendBuffer(cacheBufferId);
|
||||
if (preAllocSendBuffer != nullptr)
|
||||
{
|
||||
TLLM_CHECK(preAllocSendBuffer->getDataType() == inputKvCacheBlocks.front()->getDataType());
|
||||
TLLM_CHECK(preAllocSendBuffer->getDataType() == inputKvCacheBlocks.begin()->second.front()->getDataType());
|
||||
}
|
||||
auto sendBufferFun = [&](int deviceId, size_t processIdx)
|
||||
{
|
||||
@ -365,20 +416,40 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
|
||||
|
||||
TLLM_LOG_DEBUG("pickUpConnections size: %d connections size: %d", pickUpConnections.size(), connections.size());
|
||||
std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
|
||||
std::vector<runtime::ITensor::SharedPtr> outputBuffers;
|
||||
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> outputBuffersPerWindow;
|
||||
auto const numPools = mCacheManager->getBlockManager().getNumPools();
|
||||
// TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...
|
||||
size_t blockNum = 0;
|
||||
size_t cacheBlockSizeSum = 0;
|
||||
for (auto poolIdx = 0; poolIdx < numPools; poolIdx++)
|
||||
{
|
||||
blockRange.updatePoolIdx(poolIdx);
|
||||
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(poolIdx);
|
||||
TLLM_CHECK_WITH_INFO(outputBuffersPerWindow.find(window) == outputBuffersPerWindow.end(),
|
||||
"window size already exists, which is not supported");
|
||||
outputBuffersPerWindow.emplace(window, std::vector<runtime::ITensor::SharedPtr>());
|
||||
auto maxBlockThisWindow = window / selfConfig.getModelConfig().mTokensPerBlock;
|
||||
SizeType32 blockNumThisWindow = 0;
|
||||
for (auto it = blockRange.begin(); it != blockRange.end(); ++it)
|
||||
{
|
||||
blockNum++;
|
||||
outputBuffers.push_back(it);
|
||||
blockNumThisWindow++;
|
||||
outputBuffersPerWindow.at(window).push_back(it);
|
||||
cacheBlockSizeSum += it->getSize();
|
||||
if (blockNumThisWindow >= maxBlockThisWindow)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
TLLM_CHECK(!outputBuffersPerWindow.empty());
|
||||
if (outputBuffersPerWindow.size() > 1)
|
||||
{
|
||||
if (selfConfig.getParallelConfig().mPipelineParallelism != destConfig.getParallelConfig().mPipelineParallelism)
|
||||
{
|
||||
checkAlternateWindow(mCacheManager, selfConfig, destConfig);
|
||||
}
|
||||
}
|
||||
TLLM_CHECK(!outputBuffers.empty());
|
||||
{
|
||||
NVTX3_SCOPED_RANGE(formatInputRecvBuffer);
|
||||
|
||||
@ -437,9 +508,10 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
|
||||
}
|
||||
{
|
||||
NVTX3_SCOPED_RANGE(formatInputConcatenate);
|
||||
executor::kv_cache::concatenateKVCacheDispatch(recvBufferTmps.data(), recvBufferTmps.size(),
|
||||
getCounterparts(selfConfig, selfIdx, destConfig), destConfig, outputBuffers.data(),
|
||||
outputBuffers.size(), selfIdx, selfConfig, bufferManager);
|
||||
executor::kv_cache::concatKVCacheDispatch(recvBufferTmps.data(), recvBufferTmps.size(),
|
||||
getCounterparts(selfConfig, selfIdx, destConfig), destConfig,
|
||||
outputBuffersPerWindow.begin()->second.data(), outputBuffersPerWindow.begin()->second.size(),
|
||||
selfIdx, selfConfig, bufferManager);
|
||||
bufferManager.getStream().synchronize();
|
||||
}
|
||||
}
|
||||
@ -458,10 +530,13 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
|
||||
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
|
||||
for (auto const& connection : pickUpConnections)
|
||||
{
|
||||
for (auto const& block : outputBuffers)
|
||||
for (auto const& [window, blocks] : outputBuffersPerWindow)
|
||||
{
|
||||
llmRequest.updateKvCacheSize((*block).getSizeInBytes());
|
||||
TransferHelper::recvBuffer(*connection, *block, reqId);
|
||||
for (auto const& block : blocks)
|
||||
{
|
||||
llmRequest.updateKvCacheSize((*block).getSizeInBytes());
|
||||
TransferHelper::recvBuffer(*connection, *block, reqId);
|
||||
}
|
||||
}
|
||||
}
|
||||
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
|
||||
@ -480,12 +555,10 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
|
||||
runtime::ITensor::SharedPtr recvBufferTemp;
|
||||
std::vector<runtime::ITensor::SharedPtr> recvSplitCaches;
|
||||
|
||||
auto cacheBlockSize = outputBuffers.front()->getSize();
|
||||
|
||||
auto dataType = outputBuffers.front()->getDataType();
|
||||
auto dataType = outputBuffersPerWindow.begin()->second.front()->getDataType();
|
||||
auto targetNum = pickUpConnections.size();
|
||||
TLLM_CHECK((cacheBlockSize * blockNum) % targetNum == 0);
|
||||
auto targetBufferSize = (cacheBlockSize * blockNum) / targetNum;
|
||||
TLLM_CHECK(cacheBlockSizeSum % targetNum == 0);
|
||||
auto targetBufferSize = cacheBlockSizeSum / targetNum;
|
||||
|
||||
size_t remainNoCoverTargetNum = 0;
|
||||
size_t bufferCoverTargetNum = 0;
|
||||
@ -494,7 +567,6 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
|
||||
NVTX3_SCOPED_RANGE(formatInputAllocBuffer);
|
||||
|
||||
TLLM_CHECK(blockNum > 0);
|
||||
TLLM_CHECK(outputBuffers.size() == blockNum);
|
||||
if (legacyPath)
|
||||
{
|
||||
|
||||
@ -662,14 +734,16 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
|
||||
|
||||
if (legacyPath)
|
||||
{
|
||||
executor::kv_cache::concatenateKVCacheDispatch(recvSplitCaches.data(), recvSplitCaches.size(),
|
||||
getCounterparts(selfConfig, selfIdx, destConfig), destConfig, outputBuffers.data(),
|
||||
outputBuffers.size(), selfIdx, selfConfig, bufferManager);
|
||||
TLLM_CHECK(outputBuffersPerWindow.size() == 1);
|
||||
executor::kv_cache::concatKVCacheDispatch(recvSplitCaches.data(), recvSplitCaches.size(),
|
||||
getCounterparts(selfConfig, selfIdx, destConfig), destConfig,
|
||||
outputBuffersPerWindow.begin()->second.data(), outputBuffersPerWindow.begin()->second.size(),
|
||||
selfIdx, selfConfig, bufferManager);
|
||||
}
|
||||
else
|
||||
{
|
||||
executor::kv_cache::concatenateKvCacheV2Dispatch(
|
||||
recvSplitCaches, outputBuffers, destConfig, selfConfig, selfIdx, bufferManager);
|
||||
executor::kv_cache::concatKvCacheV2Dispatch(
|
||||
recvSplitCaches, outputBuffersPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
|
||||
}
|
||||
bufferManager.getStream().synchronize();
|
||||
if (cacheBufferId.has_value())
|
||||
|
||||
@ -23,8 +23,7 @@
|
||||
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/executor/cacheCommunicator.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheConcatenate.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
|
||||
#include "tensorrt_llm/executor/dataTransceiverState.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
@ -194,20 +194,29 @@ CacheTransBufferManager::CacheTransBufferManager(
|
||||
, mBufferManager{std::make_shared<runtime::CudaStream>()}
|
||||
{
|
||||
|
||||
// TODO: FP4 dataSize
|
||||
TLLM_CHECK(mCacheManager);
|
||||
mDataType = mCacheManager->getPrimaryPool(0)->getDataType();
|
||||
|
||||
auto tokensPerBlock = mCacheManager->getBlockManager().getTokensPerBlock();
|
||||
size_t bufferSizeFromMaxNumToken = 0;
|
||||
if (maxNumTokens.has_value())
|
||||
{
|
||||
TLLM_CHECK(maxNumTokens.value() % tokensPerBlock == 0);
|
||||
auto dataSize = common::getDTypeSize(mDataType);
|
||||
auto kvCacheByteSizePerTokenPerLayer = mCacheManager->getBlockManager().getBlockSize(0) / tokensPerBlock
|
||||
* (mCacheManager->getCacheType() == CacheType::kSELFKONLY ? 1 : 2) * dataSize;
|
||||
for (auto layerId = 0; layerId < mCacheManager->getBlockManager().getNumLayers(); layerId++)
|
||||
{
|
||||
auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId);
|
||||
auto windowSize = static_cast<size_t>(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx));
|
||||
auto validTokenNum = windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value();
|
||||
bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer;
|
||||
}
|
||||
}
|
||||
auto kvCachePerToken
|
||||
= (mCacheManager->getBlockManager().getBlockSize(0) * mCacheManager->getBlockManager().getNumLayers()
|
||||
* (mCacheManager->getCacheType() == CacheType::kSELFKONLY ? 1 : 2))
|
||||
/ tokensPerBlock;
|
||||
mTransferBufferSize = maxNumTokens.has_value() ? maxNumTokens.value() * kvCachePerToken
|
||||
: common::getEnvMemSizeForKVCacheTransferBuffer();
|
||||
|
||||
mTransferBufferSize
|
||||
= maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
|
||||
mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
|
||||
mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
|
||||
mSendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
|
||||
|
||||
@ -517,7 +517,9 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
|
||||
// Gather the kv cache transfer time from all workers and update to leader rank
|
||||
if (!common::getEnvKVCacheTransferOutputPath().empty())
|
||||
{
|
||||
updateKVCacheTransferBW(*mMpiGroupComm, it->first);
|
||||
auto syncComm
|
||||
= mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupDataComm.get() : mMpiGroupComm;
|
||||
updateKVCacheTransferBW(*syncComm, it->first);
|
||||
}
|
||||
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
|
||||
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
|
||||
|
||||
@ -151,7 +151,9 @@ void DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
|
||||
|
||||
RequestInfo requestInfo(requestId, mSelfState);
|
||||
|
||||
if (!common::getEnvDisableSelectiveCacheTransfer())
|
||||
auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer()
|
||||
|| (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1);
|
||||
if (!disableSelectiveCacheTransfer)
|
||||
{
|
||||
auto* cacheManager = mFormatter->getCacheManager();
|
||||
auto blockRange
|
||||
|
||||
@ -18,7 +18,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "cacheFormatter.h"
|
||||
#include "cacheTransBuffer.h"
|
||||
#include "dataTransceiver.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
@@ -42,12 +42,13 @@ KVCacheEventManager::~KVCacheEventManager()
mWorkerThread.join();
}
void KVCacheEventManager::enqueueCreatedEvent(std::vector<SizeType32> const& numBlocksPerCacheLevel)
void KVCacheEventManager::enqueueCreatedEvent(
std::vector<SizeType32> const& numBlocksPerCacheLevel, SizeType32 windowSize)
{
enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}});
enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize});
}
void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks)
void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks, SizeType32 windowSize)
{
if (blocks.empty())
{

@@ -67,24 +68,26 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
}
enqueueEvent({mEventId++, data});
enqueueEvent({mEventId++, data, windowSize});
}
void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block)
void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize)
{
if (!mEventQueue.empty() && std::holds_alternative<tle::KVCacheRemovedData>(mEventQueue.back().data))
// We can only batch the removed block events if the same sliding window size is used.
if (!mEventQueue.empty() && mEventQueue.back().windowSize == windowSize
&& std::holds_alternative<tle::KVCacheRemovedData>(mEventQueue.back().data))
{
std::get<tle::KVCacheRemovedData>(mEventQueue.back().data).blockHashes.push_back(block->getHash());
}
else
{
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}});
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize});
}
}
void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data)
void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize)
{
enqueueEvent({mEventId++, data});
enqueueEvent({mEventId++, data, windowSize});
}
void KVCacheEventManager::enqueueEvent(tle::KVCacheEvent&& event)
@ -552,7 +552,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
|
||||
mAllBlocksById, {blocksInPrimaryPool, blocksInSecondaryPool}, secondaryOffloadMinPriority);
|
||||
if (mEventManager)
|
||||
{
|
||||
mEventManager->enqueueCreatedEvent({blocksInPrimaryPool, blocksInSecondaryPool});
|
||||
mEventManager->enqueueCreatedEvent({blocksInPrimaryPool, blocksInSecondaryPool}, mWindowSize);
|
||||
}
|
||||
}
|
||||
|
||||
@ -741,7 +741,7 @@ void WindowBlockManager::freeChildren(
|
||||
// Free block
|
||||
if (mEventManager && blockInRadixTree(block))
|
||||
{
|
||||
mEventManager->enqueueRemovedEvent(block);
|
||||
mEventManager->enqueueRemovedEvent(block, mWindowSize);
|
||||
}
|
||||
|
||||
claimLeafBlock(block, priority, durationMs);
|
||||
@ -776,7 +776,8 @@ BlockPtr WindowBlockManager::getFreeBlock(
|
||||
if (mEventManager && blockInRadixTree(block))
|
||||
{
|
||||
mEventManager->enqueueUpdatedEvent(
|
||||
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel));
|
||||
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel),
|
||||
mWindowSize);
|
||||
}
|
||||
mEvictionPolicy->releaseBlock(block); // append offload block to mFreeSecondaryBlocks queue
|
||||
block = offloadBlock;
|
||||
@ -881,7 +882,8 @@ void WindowBlockManager::onboardBlock(BlockPtr const& offloadBlock)
|
||||
if (mEventManager)
|
||||
{
|
||||
mEventManager->enqueueUpdatedEvent(
|
||||
tle::KVCacheUpdatedData(offloadBlock->getHash()).cacheLevelUpdated(kSecondaryLevel, kPrimaryLevel));
|
||||
tle::KVCacheUpdatedData(offloadBlock->getHash()).cacheLevelUpdated(kSecondaryLevel, kPrimaryLevel),
|
||||
mWindowSize);
|
||||
}
|
||||
mEvictionPolicy->releaseBlock(block); // append block to offload queue
|
||||
// offloadBlock is now in primary memory pool
|
||||
@ -908,7 +910,8 @@ void WindowBlockManager::offloadBlock(BlockPtr const& block)
|
||||
if (mEventManager && blockInRadixTree(block))
|
||||
{
|
||||
mEventManager->enqueueUpdatedEvent(
|
||||
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel));
|
||||
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel),
|
||||
mWindowSize);
|
||||
}
|
||||
mEvictionPolicy->releaseBlock(offloadBlock); // append offloadBlock to mFreePrimaryBlocks queue
|
||||
// block is now in secondary memory
|
||||
@ -980,7 +983,8 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
|
||||
{
|
||||
mEventManager->enqueueUpdatedEvent(
|
||||
tle::KVCacheUpdatedData(matchingBlock->getHash())
|
||||
.priorityUpdated(matchingBlock->getPriority(), *perBlockRetentions[bi].retentionPriority));
|
||||
.priorityUpdated(matchingBlock->getPriority(), *perBlockRetentions[bi].retentionPriority),
|
||||
mWindowSize);
|
||||
}
|
||||
if (partialMatch)
|
||||
{
|
||||
@ -1275,7 +1279,7 @@ void WindowBlockManager::storeBlocks(
|
||||
}
|
||||
if (mEventManager)
|
||||
{
|
||||
mEventManager->enqueueStoredEvent(storedBlocks);
|
||||
mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1411,6 +1415,77 @@ void BlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmReq
|
||||
}
|
||||
}
|
||||
|
||||
void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
|
||||
{
|
||||
// we store newest block for potential reuse only if:
|
||||
// - Block reuse is enabled.
|
||||
// - A request was provided to this function call to identify which tokens these blocks cover
|
||||
// - Beam search is NOT enabled <=> beam width == 1
|
||||
// - The sequence was not marked for use with cyclic kv-cache when it was added (when its context is too long to fit
|
||||
// the max attention window).
|
||||
// - The sequence did not switch to cyclic kv-cache during generation phase.
|
||||
// A sequence is cyclic if its *minimum window size* is crossed, even if other window sizes were not reached.
|
||||
bool const storeBlocksForReuse = sequence.getBeamWidth() == 1 && llmRequest.has_value() && !sequence.isCyclic();
|
||||
if (!storeBlocksForReuse)
|
||||
{
|
||||
return;
|
||||
}
|
||||
for (auto& [_, manager] : mWindowBlockManagers)
|
||||
{
|
||||
manager.storeNewBlock(sequence, llmRequest);
|
||||
}
|
||||
}
|
||||
|
||||
void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
|
||||
{
|
||||
auto constexpr beamIdx = 0;
|
||||
auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx);
|
||||
auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize);
|
||||
|
||||
if (uniqueTokens.size() == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't
|
||||
// have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume
|
||||
// the last token's state is not filled yet.
|
||||
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
|
||||
if (usableSize % mTokensPerBlock != 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
|
||||
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
|
||||
if (blockKeys.size() < 2 || cacheBlockIds[beamIdx].size() < blockKeys.size())
|
||||
{
|
||||
// store all blocks
|
||||
TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str());
|
||||
storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
|
||||
return;
|
||||
}
|
||||
|
||||
auto lastBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 1]);
|
||||
auto prevBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 2]);
|
||||
|
||||
// If the previous block is not in the radix tree, we need to store all blocks
|
||||
if (prevBlock->getPrevBlock() == nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str());
|
||||
storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
|
||||
return;
|
||||
}
|
||||
|
||||
if (lastBlock->getPrevBlock() != nullptr)
|
||||
{
|
||||
// If the last block is already in the radix tree, there is nothing new to store
|
||||
TLLM_LOG_DEBUG("%s::storeNewBlock - no need to store", mLogPrefix.c_str());
|
||||
return;
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s::storeNewBlock - store the last block", mLogPrefix.c_str());
|
||||
storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
|
||||
}
|
||||
|
||||
void WindowBlockManager::storeBlocksForReuse(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
|
||||
{
|
||||
auto constexpr beamIdx = 0;
|
||||
@ -1960,6 +2035,17 @@ void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest)
|
||||
}
|
||||
}
|
||||
|
||||
void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest)
|
||||
{
|
||||
auto const requestId = llmRequest.mRequestId;
|
||||
auto& sequence = getSequence(requestId);
|
||||
bool const storeBlocksForReuse = sequence.getBeamWidth() == 1 && !sequence.isCyclic();
|
||||
if (mEnableBlockReuse && storeBlocksForReuse)
|
||||
{
|
||||
mBlockManager.storeNewBlock(sequence, llmRequest);
|
||||
}
|
||||
}
|
||||
|
||||
void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest)
|
||||
{
|
||||
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -118,6 +118,62 @@ std::pair<std::vector<SizeType32>, std::vector<SizeType32>> getActiveSlots(
|
||||
return {activeSlots, generationSteps};
|
||||
}
|
||||
|
||||
//! @brief Sets inputs for explicit draft tokens.
|
||||
void setExplicitDraftTokensInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntimeBuffers)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(fusedRuntimeBuffers.mExplicitDraftTokensBuffers);
|
||||
auto const& explicitDraftTokensInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers->engineOutputs;
|
||||
auto const& explicitDraftTokensLastInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers->engineInputs;
|
||||
|
||||
dInput.explicitDraftTokensInputs = tr::DecodingInput::ExplicitDraftTokensInputs();
|
||||
dInput.explicitDraftTokensInputs->nextDraftTokens = explicitDraftTokensInputs.nextDraftTokens;
|
||||
dInput.explicitDraftTokensInputs->nextFlatTokens = explicitDraftTokensInputs.nextFlatTokens;
|
||||
dInput.explicitDraftTokensInputs->nextDraftIndices = explicitDraftTokensInputs.nextDraftIndices;
|
||||
dInput.explicitDraftTokensInputs->nextDraftProbs = explicitDraftTokensInputs.nextDraftProbs;
|
||||
dInput.explicitDraftTokensInputs->lastDraftTokens = explicitDraftTokensLastInputs.draftTokens;
|
||||
dInput.explicitDraftTokensInputs->lastDraftIndices = explicitDraftTokensLastInputs.draftIndices;
|
||||
dInput.explicitDraftTokensInputs->lastPositionIdsBase = explicitDraftTokensLastInputs.positionIdsBase;
|
||||
dInput.explicitDraftTokensInputs->masks = explicitDraftTokensInputs.masks;
|
||||
dInput.explicitDraftTokensInputs->packedPositionIds = explicitDraftTokensInputs.packedPositionIds;
|
||||
dInput.explicitDraftTokensInputs->bestPathLengths = explicitDraftTokensInputs.bestPathLengths;
|
||||
dInput.explicitDraftTokensInputs->bestPathIndices = explicitDraftTokensInputs.bestPathIndices;
|
||||
dInput.explicitDraftTokensInputs->nextGenerationLengths = explicitDraftTokensInputs.nextGenerationLengths;
|
||||
dInput.explicitDraftTokensInputs->lastGenerationLengths = explicitDraftTokensLastInputs.generationLengths;
|
||||
dInput.explicitDraftTokensInputs->maxGenLengthDevice = explicitDraftTokensInputs.maxGenToken;
|
||||
// Slots in request order
|
||||
dInput.explicitDraftTokensInputs->seqSlots = fusedRuntimeBuffers.seqSlots;
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
//! @brief Sets inputs for eagle decoding.
|
||||
void setEagleInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntimeBuffers)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(fusedRuntimeBuffers.mEagleBuffers);
|
||||
auto const& eagleInputs = fusedRuntimeBuffers.mEagleBuffers->engineOutputs;
|
||||
auto const& eagleLastInputs = fusedRuntimeBuffers.mEagleBuffers->engineInputs;
|
||||
|
||||
dInput.eagleInputs = tr::DecodingInput::EagleInputs();
|
||||
dInput.eagleInputs->nextDraftTokens = eagleInputs.nextDraftTokens;
|
||||
dInput.eagleInputs->nextDraftLens = eagleInputs.nextDraftLens;
|
||||
dInput.eagleInputs->nextDraftPaths = eagleInputs.nextDraftPaths;
|
||||
dInput.eagleInputs->lastDraftTokens = eagleLastInputs.draftTokens;
|
||||
dInput.eagleInputs->lastDraftLens = eagleLastInputs.draftLens;
|
||||
dInput.eagleInputs->lastDraftPaths = eagleLastInputs.draftPaths;
|
||||
dInput.eagleInputs->acceptedTokens = eagleInputs.acceptedTokens;
|
||||
dInput.eagleInputs->acceptedLens = eagleInputs.acceptedLens;
|
||||
dInput.eagleInputs->acceptedPathIds = eagleInputs.acceptedPaths;
|
||||
dInput.eagleInputs->chunkedContextNextTokens = eagleInputs.chunkedContextNextTokens;
|
||||
// Slots in request order
|
||||
dInput.eagleInputs->seqSlots = fusedRuntimeBuffers.seqSlots;
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests,
|
||||
@ -131,28 +187,30 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator
|
||||
|
||||
auto decodingInput = createDecoderBatchInputs(
|
||||
activeSlots, decoderState, inputBuffers.logits, maxNumSequences, inputBuffers.forwardBatchSlots);
|
||||
decodingInput->generationSteps = generationSteps;
|
||||
|
||||
auto const maxBeamWidth = decoderState.getMaxBeamWidth();
|
||||
if (maxBeamWidth > 1)
|
||||
{
|
||||
// For Variable-Beam-Width-Search
|
||||
decoderState.getJointDecodingInput().generationSteps = generationSteps;
|
||||
}
|
||||
|
||||
if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
|
||||
{
|
||||
decodingInput->predictedDraftLogits = inputBuffers.predictedDraftLogits;
|
||||
decoderState.getJointDecodingInput().medusaInputs->medusaLogits = inputBuffers.predictedDraftLogits;
|
||||
}
|
||||
|
||||
if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
|
||||
{
|
||||
TLLM_CHECK(fusedRuntimeBuffers);
|
||||
// requires mCtxGenFusion == true
|
||||
decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
|
||||
decodingInput->explicitDraftTokensInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers->engineOutputs;
|
||||
decodingInput->explicitDraftTokensLastInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers->engineInputs;
|
||||
setExplicitDraftTokensInputs(decoderState.getJointDecodingInput(), *fusedRuntimeBuffers);
|
||||
}
|
||||
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
|
||||
{
|
||||
TLLM_CHECK(fusedRuntimeBuffers);
|
||||
// requires mCtxGenFusion == true
|
||||
decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
|
||||
decodingInput->eagleInputs = fusedRuntimeBuffers->mEagleBuffers->engineOutputs;
|
||||
decodingInput->eagleLastInputs = fusedRuntimeBuffers->mEagleBuffers->engineInputs;
|
||||
setEagleInputs(decoderState.getJointDecodingInput(), *fusedRuntimeBuffers);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -18,14 +18,13 @@
|
||||
#include "mlaCacheFormatter.h"
|
||||
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/contextProgress.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/dataType.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/nvtxUtils.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheConcatenate.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
@ -170,8 +169,11 @@ void MLACacheFormatter::formatOutput(LlmRequest const& llmRequest,
|
||||
|
||||
// The size of outputSplitCaches should be equal to pPDomainSize
|
||||
|
||||
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(0);
|
||||
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocksPerWindow;
|
||||
inputKvCacheBlocksPerWindow.emplace(window, inputKvCacheBlocks);
|
||||
tensorrt_llm::executor::kv_cache::splitKVCacheDispatch(
|
||||
inputKvCacheBlocks, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager);
|
||||
inputKvCacheBlocksPerWindow, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager);
|
||||
|
||||
bufferManager.getStream().synchronize();
|
||||
|
||||
@ -185,25 +187,28 @@ void MLACacheFormatter::formatOutput(LlmRequest const& llmRequest,
|
||||
NVTX3_SCOPED_RANGE(sendBufferFun);
|
||||
|
||||
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
|
||||
auto startTime = std::chrono::steady_clock::now();
|
||||
auto cacheIdx = processIdx % pPDomainSize;
|
||||
size_t size;
|
||||
if (cacheIdx < bufferCoverTargetNum)
|
||||
{
|
||||
|
||||
size = outputSplitCaches.at(cacheIdx)->getSizeInBytes();
|
||||
TransferHelper::sendBuffer(*connections.at(processIdx), *outputSplitCaches.at(cacheIdx), reqId);
|
||||
}
|
||||
else if (bufferCoverTargetNum > 0)
|
||||
{
|
||||
// copy buffer allocated by cudaMallocAsync to buffer allocated by cudaMalloc before sending
|
||||
auto sendBufferIdx = cacheIdx % bufferCoverTargetNum;
|
||||
size = outputSplitCaches.at(sendBufferIdx)->getSizeInBytes();
|
||||
bufferManager.copy(*outputSplitCaches.at(cacheIdx), *outputSplitCaches.at(sendBufferIdx));
|
||||
bufferManager.getStream().synchronize();
|
||||
TransferHelper::sendBuffer(*connections.at(processIdx), *outputSplitCaches.at(sendBufferIdx), reqId);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
// bufferCoverTargetNum=0, mSendBuffer size < one outputSlice
|
||||
// send multiple times
|
||||
size = targetBufferSize;
|
||||
size_t remainSendSize = targetBufferSize;
|
||||
while (remainSendSize > 0)
|
||||
{
|
||||
@ -220,6 +225,10 @@ void MLACacheFormatter::formatOutput(LlmRequest const& llmRequest,
|
||||
remainSendSize -= sendSize;
|
||||
}
|
||||
}
|
||||
auto endTime = std::chrono::steady_clock::now();
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, cacheTransferTime, size);
|
||||
};
|
||||
|
||||
if (connections.size() > 1)
|
||||
@ -446,11 +455,14 @@ void MLACacheFormatter::formatInput(LlmRequest const& llmRequest,
|
||||
}
|
||||
|
||||
{
|
||||
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> outputCachesPerWindow;
|
||||
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(0);
|
||||
outputCachesPerWindow.emplace(window, outputBuffers);
|
||||
NVTX3_SCOPED_RANGE(formatInputConcatenate);
|
||||
|
||||
// recvSplitCaches size == ppdomainsize
|
||||
executor::kv_cache::concatenateKvCacheV2Dispatch(
|
||||
recvSplitCaches, outputBuffers, destConfig, selfConfig, selfIdx, bufferManager);
|
||||
executor::kv_cache::concatKvCacheV2Dispatch(
|
||||
recvSplitCaches, outputCachesPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
|
||||
}
|
||||
bufferManager.getStream().synchronize();
|
||||
}
|
||||
|
||||
@ -64,6 +64,7 @@ public:
|
||||
private:
|
||||
BaseKVCacheManager* mCacheManager;
|
||||
CacheTransBufferManager* mCacheTransBufferManager;
|
||||
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
@ -929,6 +929,25 @@ void TrtGptModelInflightBatching::storeContextBlocks(std::shared_ptr<LlmRequest>
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void TrtGptModelInflightBatching::storeNewBlock(std::shared_ptr<LlmRequest> const& llmReq)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
// TMJ - Note
|
||||
// Make context blocks reusable immediately after each generation step.
|
||||
|
||||
if (mKvCacheManager)
|
||||
{
|
||||
mKvCacheManager->storeNewBlock(*llmReq);
|
||||
}
|
||||
if (mCrossKvCacheManager)
|
||||
{
|
||||
mCrossKvCacheManager->storeNewBlock(*llmReq);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void TrtGptModelInflightBatching::resetIterationStats()
|
||||
{
|
||||
mLastIterationStatsIFB = IterationStatsIFB{mMicroBatchId};
|
||||
@ -1099,6 +1118,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
|
||||
}
|
||||
else if (llmReq->isGenerationInProgressState())
|
||||
{
|
||||
storeNewBlock(llmReq);
|
||||
TLLM_LOG_DEBUG("request with ID %lu forwards a step in decoder gen phase", llmReq->mRequestId);
|
||||
}
|
||||
}
|
||||
|
||||
@ -248,6 +248,10 @@ private:
|
||||
//! These blocks become reusable from next step.
|
||||
void storeContextBlocks(std::shared_ptr<LlmRequest> const& req);
|
||||
|
||||
//! @brief Store newest kv cache block for reuse.
|
||||
//! The block becomes reusable from the next step.
|
||||
void storeNewBlock(std::shared_ptr<LlmRequest> const& req);
|
||||
|
||||
//! @brief Set LayerProfiler to collect performance per layer.
|
||||
void setLayerProfiler() override;
|
||||
|
||||
|
||||
@@ -2409,22 +2409,22 @@ int AttentionOp::initialize() noexcept
if (mFP8ContextFMHA)
{
TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "FP8 FMHA cannot be enabled because Context FMHA is not supported.");
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120,
"FP8 FMHA can only be enabled on sm_89, sm_90, sm_100 or sm_120.");
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120 || mSM == 121,
"FP8 FMHA can only be enabled on sm_89, sm_90, sm_100, sm_120 or sm_121.");
}

// Pre-Check of FP8 Generation MLA.
if (mFP8GenerationMLA)
{
TLLM_CHECK_WITH_INFO(mIsMLAEnabled, "FP8 Generation MLA cannot be enabled because MLA is not supported.");
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120,
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120 || mSM == 121,
"FP8 Generation MLA is supported on Ada, Hopper or Blackwell architecture.");
}

// Check requirements for FP4 output.
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mEnableContextFMHA, "Context FMHA must enable if fuse_fp4_quant is enabled");
TLLM_CHECK_WITH_INFO(
!mFuseFp4Quant || mSM == 100 || mSM == 120, "fuse_fp4_quant only supports SM100 or SM120 devices.");
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121,
"fuse_fp4_quant only supports SM100 or SM120 or SM121 devices.");

TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0));
TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
@@ -332,12 +332,12 @@ enum class ClusterShape
ClusterShape_1x2x1,
ClusterShape_2x2x1,
ClusterShape_1x4x1,
ClusterShape_4x1x1,
ClusterShape_4x2x1,
ClusterShape_2x4x1,
ClusterShape_4x4x1,
ClusterShape_1x8x1,
ClusterShape_8x1x1,
ClusterShape_4x1x1
ClusterShape_8x1x1
};

static auto get_cluster_shape_name(ClusterShape Shape_MNK)

@@ -484,9 +484,9 @@ struct CutlassGemmConfig
int getTileConfigAsInt() const
{
if (sm_version == 120)
if (sm_version == 120 || sm_version == 121)
return (int) tile_config_sm120;
if (sm_version >= 100)
if (sm_version >= 100 && sm_version < 120)
return (int) tile_config_sm100;
if (sm_version == 90)
return (int) tile_config_sm90;
@ -22,6 +22,8 @@
|
||||
|
||||
#include "cutlass/barrier.h"
|
||||
|
||||
#include <cuda/atomic>
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
|
||||
@ -43,7 +45,7 @@ __forceinline__ __device__ uint32_t atomicCAS_system_acq(uint32_t* p, uint32_t c
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class Sync, bool SafeBetweenPhases, bool UseMembarGPU>
|
||||
template <class Sync, bool SafeBetweenPhases>
|
||||
struct MulticastSystemBarrier : public GenericBarrier<Sync>
|
||||
{
|
||||
|
||||
@ -57,8 +59,8 @@ struct MulticastSystemBarrier : public GenericBarrier<Sync>
|
||||
|
||||
protected:
|
||||
/// Reduce into flag, with release pattern (int specialization)
|
||||
CUTLASS_DEVICE
|
||||
static void red_release(T* mc_ptr, int val)
|
||||
template <cuda::thread_scope Scope>
|
||||
CUTLASS_DEVICE static void red_release(T* mc_ptr, int val)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MULTIMEM_SM90_ENABLED)
|
||||
// atomic reduction to all replicas
|
||||
@ -66,14 +68,18 @@ protected:
|
||||
// See
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-multimem-ld-reduce-multimem-st-multimem-red
|
||||
// for multimem PTX doc
|
||||
if constexpr (UseMembarGPU)
|
||||
if constexpr (Scope == cuda::thread_scope::thread_scope_device)
|
||||
{
|
||||
asm volatile("multimem.red.release.gpu.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
|
||||
}
|
||||
else
|
||||
else if constexpr (Scope == cuda::thread_scope::thread_scope_system)
|
||||
{
|
||||
asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTE_INVALID_CONTROL_PATH("Invalid thread scope for MulticastSystemBarrier.");
|
||||
}
|
||||
|
||||
// Need a fence between MC and UC access to the same memory:
|
||||
// - fence.proxy instructions establish an ordering between memory accesses that may happen through different
|
||||
@ -128,8 +134,8 @@ public:
|
||||
Sync::sync();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static T arrive_inc_get(T* mc_ptr, T* uc_ptr, int thread_idx, int flag_idx, int rank, int world_size)
|
||||
template <cuda::thread_scope Scope>
|
||||
CUTLASS_DEVICE static T arrive_inc_get(T* mc_ptr, T* uc_ptr, int thread_idx, int flag_idx, int rank, int world_size)
|
||||
{
|
||||
T* mc_barrier_ptr = mc_ptr + flag_idx;
|
||||
T* uc_barrier_ptr = uc_ptr + flag_idx;
|
||||
@ -156,13 +162,13 @@ public:
|
||||
// can be immediately reused.
|
||||
bool master = rank == 0;
|
||||
int val = master ? 0x80000000 - (world_size - 1) : 1;
|
||||
red_release(mc_barrier_ptr, val);
|
||||
red_release<Scope>(mc_barrier_ptr, val);
|
||||
}
|
||||
return old_arrive;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void arrive_inc(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
|
||||
template <cuda::thread_scope Scope = cuda::thread_scope::thread_scope_system>
|
||||
CUTLASS_DEVICE static void arrive_inc(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
|
||||
{
|
||||
T* mc_barrier = params.mc_barrier_ptr + flag_idx;
|
||||
|
||||
@ -170,23 +176,24 @@ public:
|
||||
|
||||
if (thread_idx == 0)
|
||||
{
|
||||
red_release(mc_barrier, 1);
|
||||
red_release<Scope>(mc_barrier, 1);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void arrive_and_wait(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
|
||||
template <cuda::thread_scope Scope = cuda::thread_scope::thread_scope_system>
|
||||
CUTLASS_DEVICE static void arrive_and_wait(
|
||||
Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
|
||||
{
|
||||
auto mc_ptr = params.mc_barrier_ptr;
|
||||
auto uc_ptr = params.uc_barrier_ptr;
|
||||
if constexpr (SafeBetweenPhases)
|
||||
{
|
||||
auto old_arrive = arrive_inc_get(mc_ptr, uc_ptr, thread_idx, flag_idx, rank, world_size);
|
||||
auto old_arrive = arrive_inc_get<Scope>(mc_ptr, uc_ptr, thread_idx, flag_idx, rank, world_size);
|
||||
wait(old_arrive, uc_ptr, thread_idx, flag_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
arrive_inc(params, thread_idx, flag_idx, rank, world_size);
|
||||
arrive_inc<Scope>(params, thread_idx, flag_idx, rank, world_size);
|
||||
wait_eq_reset(uc_ptr, thread_idx, flag_idx, world_size);
|
||||
}
|
||||
}
|
||||
|
||||
cpp/tensorrt_llm/deep_ep/CMakeLists.txt (new file, 207 lines)
@@ -0,0 +1,207 @@
|
||||
set(DEEP_EP_COMMIT c381dadf43a85062f6a8947592017ee513abc70b)
|
||||
set(NVSHMEM_URL_HASH
|
||||
SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
|
||||
|
||||
add_custom_target(deep_ep)
|
||||
|
||||
# CUDA architectures
|
||||
# ==================
|
||||
|
||||
# Filter CUDA arch >= 9.0
|
||||
set(DEEP_EP_CUDA_ARCHITECTURES "")
|
||||
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
|
||||
string(REGEX MATCHALL "^([1-9][0-9]*)([0-9])[af]?(-real|-virtual)?$" MATCHES
|
||||
${CUDA_ARCH})
|
||||
if(NOT CMAKE_MATCH_0)
|
||||
message(FATAL_ERROR "Invalid CUDA arch format: \"${CUDA_ARCH}\"")
|
||||
endif()
|
||||
set(CUDA_ARCH_MAJOR ${CMAKE_MATCH_1})
|
||||
set(CUDA_ARCH_MINOR ${CMAKE_MATCH_2})
|
||||
set(CUDA_ARCH_POSTFIX ${CMAKE_MATCH_3})
|
||||
if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 9)
|
||||
list(APPEND DEEP_EP_CUDA_ARCHITECTURES
|
||||
"${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Skip build if there is no suitable CUDA arch
|
||||
if(WIN32)
|
||||
set(DEEP_EP_CUDA_ARCHITECTURES "")
|
||||
endif()
|
||||
message(
|
||||
STATUS "deep_ep DEEP_EP_CUDA_ARCHITECTURES: ${DEEP_EP_CUDA_ARCHITECTURES}")
|
||||
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/cuda_architectures.txt
|
||||
"${DEEP_EP_CUDA_ARCHITECTURES}")
|
||||
if(NOT DEEP_EP_CUDA_ARCHITECTURES)
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Prepare files
|
||||
# =============
|
||||
|
||||
# Download DeepEP
|
||||
include(FetchContent)
|
||||
if(DEFINED ENV{GITHUB_MIRROR} AND NOT "$ENV{GITHUB_MIRROR}" STREQUAL "")
|
||||
set(GITHUB_URL "$ENV{GITHUB_MIRROR}")
|
||||
else()
|
||||
set(GITHUB_URL "https://github.com")
|
||||
endif()
|
||||
set(DEEP_EP_URL
|
||||
"${GITHUB_URL}/deepseek-ai/DeepEP/archive/${DEEP_EP_COMMIT}.tar.gz")
|
||||
message(STATUS "deep_ep DEEP_EP_URL: ${DEEP_EP_URL}")
|
||||
FetchContent_Declare(deep_ep_download URL ${DEEP_EP_URL})
|
||||
FetchContent_MakeAvailable(deep_ep_download)
|
||||
set(DEEP_EP_SOURCE_DIR ${deep_ep_download_SOURCE_DIR})
|
||||
|
||||
# Copy and update python files
|
||||
set(DEEP_EP_PYTHON_DEST ${CMAKE_CURRENT_BINARY_DIR}/python/deep_ep)
|
||||
file(REMOVE_RECURSE ${DEEP_EP_PYTHON_DEST})
|
||||
file(MAKE_DIRECTORY ${DEEP_EP_PYTHON_DEST})
|
||||
configure_file(${DEEP_EP_SOURCE_DIR}/LICENSE ${DEEP_EP_PYTHON_DEST}/LICENSE
|
||||
COPYONLY)
|
||||
set(_files __init__.py buffer.py utils.py)
|
||||
foreach(_f IN LISTS _files)
|
||||
set(_src "${DEEP_EP_SOURCE_DIR}/deep_ep/${_f}")
|
||||
set(_dst "${DEEP_EP_PYTHON_DEST}/${_f}")
|
||||
file(READ "${_src}" _content)
|
||||
string(REPLACE "deep_ep_cpp" "tensorrt_llm.deep_ep_cpp_tllm" _content
|
||||
"${_content}")
|
||||
string(
|
||||
PREPEND
|
||||
_content
|
||||
"# Adapted from https://github.com/deepseek-ai/DeepEP/blob/${DEEP_EP_COMMIT}/deep_ep/${_f}\n"
|
||||
)
|
||||
file(WRITE "${_dst}" "${_content}")
|
||||
set_property(
|
||||
DIRECTORY
|
||||
APPEND
|
||||
PROPERTY CMAKE_CONFIGURE_DEPENDS ${_src})
|
||||
endforeach()
|
||||
|
||||
# Delete stale nvshmem on patch update
|
||||
set(NVSHMEM_STAMP_FILE ${CMAKE_CURRENT_BINARY_DIR}/nvshmem_stamp.txt)
|
||||
file(SHA256 ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch NVSHMEM_PATCH_HASH)
|
||||
file(SHA256 ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch
|
||||
NVSHMEM_PATCH_2_HASH)
|
||||
set(NVSHMEM_STAMP_CONTENT "${NVSHMEM_URL_HASH}")
|
||||
string(APPEND NVSHMEM_STAMP_CONTENT " PATCH_COMMAND v1")
|
||||
string(APPEND NVSHMEM_STAMP_CONTENT " ${NVSHMEM_PATCH_HASH}")
|
||||
string(APPEND NVSHMEM_STAMP_CONTENT " 103")
|
||||
string(APPEND NVSHMEM_STAMP_CONTENT " ${NVSHMEM_PATCH_2_HASH}")
|
||||
set(OLD_NVSHMEM_STAMP_CONTENT "")
|
||||
if(EXISTS ${NVSHMEM_STAMP_FILE})
|
||||
file(READ ${NVSHMEM_STAMP_FILE} OLD_NVSHMEM_STAMP_CONTENT)
|
||||
endif()
|
||||
if(NOT OLD_NVSHMEM_STAMP_CONTENT STREQUAL NVSHMEM_STAMP_CONTENT)
|
||||
file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/nvshmem_project-prefix)
|
||||
file(WRITE ${NVSHMEM_STAMP_FILE} "${NVSHMEM_STAMP_CONTENT}")
|
||||
endif()
|
||||
set_property(
|
||||
DIRECTORY APPEND
|
||||
PROPERTY CMAKE_CONFIGURE_DEPENDS
|
||||
${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch)
|
||||
|
||||
# Add NVSHMEM
|
||||
# ===========
|
||||
|
||||
# NVSHMEM only works with GCC. Building NVSHMEM with Clang results in
|
||||
# compilation errors. Using NVSHMEM with Clang results in slow builds and device
|
||||
# link issues.
|
||||
if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
set(CMAKE_C_COMPILER gcc)
|
||||
set(CMAKE_CXX_COMPILER g++)
|
||||
set(CMAKE_CUDA_HOST_COMPILER g++)
|
||||
endif()
|
||||
|
||||
# Add nvshmem external project
|
||||
include(ExternalProject)
|
||||
ExternalProject_Add(
|
||||
nvshmem_project
|
||||
URL file://${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_src_3.2.5-1.txz
|
||||
URL_HASH ${NVSHMEM_URL_HASH}
|
||||
PATCH_COMMAND patch -p1 --forward --batch -i
|
||||
${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch
|
||||
COMMAND sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i
|
||||
src/CMakeLists.txt
|
||||
COMMAND patch -p1 --forward --batch -i
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch
|
||||
CMAKE_CACHE_ARGS
|
||||
-DCMAKE_C_COMPILER:STRING=${CMAKE_C_COMPILER}
|
||||
-DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}
|
||||
-DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER}
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}
|
||||
-DCMAKE_CUDA_ARCHITECTURES:STRING=${DEEP_EP_CUDA_ARCHITECTURES}
|
||||
-DCMAKE_CUDA_HOST_COMPILER:STRING=${CMAKE_CUDA_HOST_COMPILER}
|
||||
-DCMAKE_CUDA_COMPILER_LAUNCHER:STRING=${CMAKE_CUDA_COMPILER_LAUNCHER}
|
||||
-DNVSHMEM_BUILD_EXAMPLES:BOOL=0
|
||||
-DNVSHMEM_BUILD_PACKAGES:BOOL=0
|
||||
-DNVSHMEM_BUILD_TESTS:BOOL=0
|
||||
-DNVSHMEM_IBGDA_SUPPORT:BOOL=1
|
||||
-DNVSHMEM_IBRC_SUPPORT:BOOL=0
|
||||
-DNVSHMEM_MPI_SUPPORT:BOOL=0
|
||||
-DNVSHMEM_PMIX_SUPPORT:BOOL=0
|
||||
-DNVSHMEM_SHMEM_SUPPORT:BOOL=0
|
||||
-DNVSHMEM_TIMEOUT_DEVICE_POLLING:BOOL=0
|
||||
-DNVSHMEM_UCX_SUPPORT:BOOL=0
|
||||
-DNVSHMEM_USE_GDRCOPY:BOOL=0
|
||||
-DNVSHMEM_USE_NCCL:BOOL=0
|
||||
INSTALL_COMMAND ""
|
||||
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build
|
||||
BUILD_BYPRODUCTS
|
||||
${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/lib/libnvshmem.a)
|
||||
add_library(nvshmem_project::nvshmem STATIC IMPORTED)
|
||||
add_dependencies(nvshmem_project::nvshmem nvshmem_project)
|
||||
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/include)
|
||||
set_target_properties(
|
||||
nvshmem_project::nvshmem
|
||||
PROPERTIES IMPORTED_LOCATION
|
||||
${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/lib/libnvshmem.a
|
||||
INTERFACE_INCLUDE_DIRECTORIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/include)
|
||||
|
||||
# Add DeepEP cpp
|
||||
# ==============
|
||||
|
||||
# Let CMake generate `fatbinData` for CUDA separable compilation. Set to FALSE
|
||||
# or TRUE are both OK, but it generates `code=lto_90a` rather than `code=sm_90a`
|
||||
# for arch `90a-real` if set to TRUE.
|
||||
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION FALSE)
|
||||
|
||||
# Find torch_python
|
||||
find_library(TORCH_PYTHON_LIB torch_python REQUIRED
|
||||
HINTS ${TORCH_INSTALL_PREFIX}/lib)
|
||||
|
||||
# Add deep_ep_cpp_tllm
|
||||
file(GLOB_RECURSE SRC_CPP ${DEEP_EP_SOURCE_DIR}/csrc/*.cpp)
|
||||
file(GLOB_RECURSE SRC_CU ${DEEP_EP_SOURCE_DIR}/csrc/*.cu)
|
||||
pybind11_add_module(deep_ep_cpp_tllm ${SRC_CPP} ${SRC_CU})
|
||||
set_target_properties(
|
||||
deep_ep_cpp_tllm
|
||||
PROPERTIES CXX_STANDARD_REQUIRED ON
|
||||
CUDA_STANDARD_REQUIRED ON
|
||||
CXX_STANDARD 17
|
||||
CUDA_STANDARD 17
|
||||
CUDA_SEPARABLE_COMPILATION ON
|
||||
CUDA_ARCHITECTURES "${DEEP_EP_CUDA_ARCHITECTURES}"
|
||||
LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/deep_ep_cpp_tllm.version
|
||||
INSTALL_RPATH "$ORIGIN/libs/nvshmem;${TORCH_INSTALL_PREFIX}/lib"
|
||||
BUILD_WITH_INSTALL_RPATH TRUE)
|
||||
target_compile_options(
|
||||
deep_ep_cpp_tllm
|
||||
PRIVATE ${TORCH_CXX_FLAGS} -O3 $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-O3>
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:--ptxas-options=--register-usage-level=10>)
|
||||
target_compile_definitions(
|
||||
deep_ep_cpp_tllm PRIVATE DISABLE_AGGRESSIVE_PTX_INSTRS
|
||||
TORCH_EXTENSION_NAME=deep_ep_cpp_tllm)
|
||||
target_link_libraries(
|
||||
deep_ep_cpp_tllm PRIVATE nvshmem_project::nvshmem ${TORCH_LIBRARIES}
|
||||
${TORCH_PYTHON_LIB})
|
||||
target_link_options(
|
||||
deep_ep_cpp_tllm PRIVATE
|
||||
-Wl,--version-script,${CMAKE_CURRENT_SOURCE_DIR}/deep_ep_cpp_tllm.version
|
||||
-Wl,--no-undefined-version)
|
||||
|
||||
# Set targets
|
||||
# ===========
|
||||
add_dependencies(deep_ep deep_ep_cpp_tllm nvshmem_project)
|
||||
cpp/tensorrt_llm/deep_ep/README.md (new file, 8 lines)
@@ -0,0 +1,8 @@
How to generate `nvshmem_fast_build.patch`?

1. Build the project without applying the `nvshmem_fast_build.patch`.
2. Link NVSHMEM to DeepEP with one NVSHMEM object file omitted.
3. Repeat step 2 until no more object files can be omitted.
4. Remove the unused files from NVSHMEM's `CMakeLists.txt`, and save the differences as `nvshmem_fast_build.patch`.

The script `strip_nvshmem_helper.py` automatically performs steps 2 and 3.
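Editor's note: the check described in steps 2-3 above amounts to relinking DeepEP against an NVSHMEM archive with one object file left out and seeing whether the linker reports undefined symbols. The sketch below illustrates a single iteration of that loop; the directory paths, the `-gencode` flag, and the object under test are placeholder assumptions, not values taken from this PR. The `strip_nvshmem_helper.py` added later in this diff is the maintained helper that automates the full loop.

# Hypothetical one-shot version of the README's steps 2-3 (all paths are assumptions).
import pathlib
import subprocess

nvshmem_objs = set(pathlib.Path("nvshmem-build").glob("**/*.o"))    # assumed NVSHMEM object layout
deep_ep_objs = sorted(pathlib.Path("deep_ep-objs").glob("**/*.o"))  # assumed DeepEP object layout
exclude = pathlib.Path("nvshmem-build/host/comm/amo.cpp.o")         # hypothetical object under test

archive = pathlib.Path("liba.a")
if archive.exists():
    archive.unlink()  # `ar` appends to an existing archive, so start from scratch
subprocess.check_call(["ar", "rcs", str(archive), *map(str, sorted(nvshmem_objs - {exclude}))])

# Relink DeepEP against the reduced archive; --no-undefined turns missing symbols into link errors.
status = subprocess.call([
    "nvcc", "-gencode=arch=compute_90,code=sm_90", "-Xlinker", "--no-undefined",
    "-shared", *map(str, deep_ep_objs), str(archive), "-o", "a.out",
])
print(("-" if status == 0 else "+"), exclude)  # "-" means the object can be omitted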
cpp/tensorrt_llm/deep_ep/deep_ep_cpp_tllm.version (new file, 4 lines)
@@ -0,0 +1,4 @@
|
||||
{
|
||||
global: PyInit_deep_ep_cpp_tllm;
|
||||
local: *;
|
||||
};
|
||||
cpp/tensorrt_llm/deep_ep/nvshmem_fast_build.patch (new file, 66 lines)
@@ -0,0 +1,66 @@
|
||||
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
|
||||
index cba899bba..c27337601 100644
|
||||
--- a/src/CMakeLists.txt
|
||||
+++ b/src/CMakeLists.txt
|
||||
@@ -264,48 +264,20 @@ set(NVSHMEM_HOST_SOURCES_NOMAXREGCOUNT
|
||||
host/comm/rma.cu
|
||||
host/stream/comm/quiet_on_stream.cu
|
||||
host/stream/comm/cuda_interface_sync.cu
|
||||
- host/stream/coll/alltoall/alltoall.cu
|
||||
host/stream/coll/barrier/barrier.cu
|
||||
- host/stream/coll/broadcast/broadcast.cu
|
||||
- host/stream/coll/fcollect/fcollect.cu
|
||||
- host/stream/coll/rdxn/reduce_and.cu
|
||||
- host/stream/coll/rdxn/reduce_or.cu
|
||||
- host/stream/coll/rdxn/reduce_xor.cu
|
||||
- host/stream/coll/rdxn/reduce_min.cu
|
||||
host/stream/coll/rdxn/reduce_max.cu
|
||||
- host/stream/coll/rdxn/reduce_prod.cu
|
||||
- host/stream/coll/rdxn/reduce_sum.cu
|
||||
host/stream/coll/rdxn/reduce_team.cu
|
||||
- host/stream/coll/reducescatter/reducescatter_and.cu
|
||||
- host/stream/coll/reducescatter/reducescatter_or.cu
|
||||
- host/stream/coll/reducescatter/reducescatter_xor.cu
|
||||
- host/stream/coll/reducescatter/reducescatter_min.cu
|
||||
- host/stream/coll/reducescatter/reducescatter_max.cu
|
||||
- host/stream/coll/reducescatter/reducescatter_prod.cu
|
||||
- host/stream/coll/reducescatter/reducescatter_sum.cu
|
||||
)
|
||||
|
||||
set(NVSHMEM_HOST_SOURCES
|
||||
host/bootstrap/bootstrap.cpp
|
||||
host/bootstrap/bootstrap_loader.cpp
|
||||
host/coll/cpu_coll.cpp
|
||||
- host/coll/alltoall/alltoall.cpp
|
||||
- host/coll/alltoall/alltoall_on_stream.cpp
|
||||
host/coll/barrier/barrier.cpp
|
||||
host/coll/barrier/barrier_on_stream.cpp
|
||||
- host/coll/broadcast/broadcast.cpp
|
||||
- host/coll/broadcast/broadcast_on_stream.cpp
|
||||
- host/coll/fcollect/fcollect.cpp
|
||||
- host/coll/fcollect/fcollect_on_stream.cpp
|
||||
- host/coll/rdxn/rdxn.cpp
|
||||
- host/coll/rdxn/rdxn_on_stream.cpp
|
||||
- host/coll/reducescatter/reducescatter.cpp
|
||||
- host/coll/reducescatter/reducescatter_on_stream.cpp
|
||||
host/comm/putget.cpp
|
||||
- host/comm/fence.cpp
|
||||
host/comm/quiet.cpp
|
||||
host/comm/sync.cpp
|
||||
- host/comm/amo.cpp
|
||||
host/proxy/proxy.cpp
|
||||
host/transport/transport.cpp
|
||||
host/transport/p2p/p2p.cpp
|
||||
@@ -1006,3 +978,12 @@ set(CPACK_RPM_PACKAGE_REQUIRES_PREUN "/sbin/ldconfig")
|
||||
|
||||
include(CPack)
|
||||
# End Installation definitions
|
||||
+
|
||||
+set_target_properties(
|
||||
+ git_commit
|
||||
+ nvshmem_device_project
|
||||
+ nvshmem_bootstrap_pmi
|
||||
+ nvshmem_bootstrap_pmi2
|
||||
+ nvshmem_host
|
||||
+ nvshmem-info
|
||||
+ PROPERTIES EXCLUDE_FROM_ALL TRUE)
|
||||
cpp/tensorrt_llm/deep_ep/nvshmem_src_3.2.5-1.txz (new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a
|
||||
size 618175
|
||||
cpp/tensorrt_llm/deep_ep/strip_nvshmem_helper.py (new file, 61 lines)
@@ -0,0 +1,61 @@
|
||||
# A helper script to detect unused NVSHMEM object files.
|
||||
#
|
||||
# The script links NVSHMEM to DeepEP with one object file removed at a time and
|
||||
# checks whether there are any undefined symbols. See README.md for details.
|
||||
# This script is not tested or QA'ed, so you may need to update this script if
|
||||
# the project structure changes or compilation options change.
|
||||
import pathlib
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
project_dir = pathlib.Path(__file__).parent.parent.parent.parent
|
||||
|
||||
# Run `find cpp/build | grep kernels/internode_ll.cu.o$` to get the directory
|
||||
deep_ep_obj_dir = project_dir / "cpp/build/tensorrt_llm/deep_ep/CMakeFiles/deep_ep_cpp_tllm.dir/__/__/_deps/deep_ep_download-src/csrc"
|
||||
assert deep_ep_obj_dir.is_dir()
|
||||
|
||||
# Run `find cpp/build | grep host/bootstrap/bootstrap.cpp.o$` to get the directory
|
||||
# Please set it to `nvshmem.dir` rather than `nvshmem_host.dir`
|
||||
nvshmem_obj_dir = project_dir / "cpp/build/tensorrt_llm/deep_ep/nvshmem-build/src/CMakeFiles/nvshmem.dir"
|
||||
assert nvshmem_obj_dir.is_dir()
|
||||
|
||||
# Parse the `-gencode` arguments
|
||||
with (project_dir /
|
||||
"cpp/build/tensorrt_llm/deep_ep/cuda_architectures.txt").open() as f:
|
||||
cuda_architectures = f.read()
|
||||
pattern = re.compile(r'^([1-9][0-9]*[0-9][af]?)(-real|-virtual)?$')
|
||||
gencode_args = []
|
||||
for cuda_arch in cuda_architectures.split(";"):
|
||||
matches = re.match(pattern, cuda_arch)
|
||||
assert matches is not None, f"Invalid cuda arch \"{cuda_arch}\""
|
||||
sm_version = matches.group(1)
|
||||
postfix = matches.group(2) or ""
|
||||
code = {
|
||||
"": f"[compute_{sm_version},sm_{sm_version}]",
|
||||
"-real": f"[sm_{sm_version}]",
|
||||
"-virtual": f"[compute_{sm_version}]",
|
||||
}[postfix]
|
||||
gencode_args.append(f"-gencode=arch=compute_{sm_version},{code=:s}")
|
||||
|
||||
temp_dir = project_dir / "cpp/build/tensorrt_llm/deep_ep/strip_nvshmem_helper"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
ranlib = temp_dir / "liba.a"
|
||||
if ranlib.exists():
|
||||
ranlib.unlink()
|
||||
|
||||
deep_ep_obj_list = sorted(deep_ep_obj_dir.glob("kernels/**/*.o"))
|
||||
nvshmem_obj_set = set(nvshmem_obj_dir.glob("**/*.o"))
|
||||
for exclude_obj in sorted(nvshmem_obj_set):
|
||||
# Create liba.a with one object file removed
|
||||
subprocess.check_call(
|
||||
["ar", "rcs", ranlib, *(nvshmem_obj_set - {exclude_obj})])
|
||||
# Test whether there are undefined symbols
|
||||
res = subprocess.call([
|
||||
"/usr/local/cuda/bin/nvcc", *gencode_args, "-Xlinker", "--no-undefined",
|
||||
"-shared", *deep_ep_obj_list, ranlib, "-o", temp_dir / "a.out"
|
||||
])
|
||||
# If there are no undefined symbols, print "-" to indicate the file can be omitted
|
||||
print("-" if res == 0 else "+",
|
||||
str(exclude_obj.relative_to(nvshmem_obj_dir))[:-2])
|
||||
# Unlink the archive file because `ar` appends existing archives
|
||||
ranlib.unlink()
|
||||
@ -18,7 +18,7 @@ set(SRCS
|
||||
cache_transmission/mpi_utils/connection.cpp
|
||||
cache_transmission/agent_utils/connection.cpp
|
||||
cache_transmission/transferAgent.cpp
|
||||
cache_transmission/cacheConcatenate.cu
|
||||
cache_transmission/cacheSplitConcat.cu
|
||||
contextPhaseParams.cpp
|
||||
debugConfig.cpp
|
||||
decodingConfig.cpp
|
||||
|
||||
(File diff suppressed because it is too large)
cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu (new file, 1540 lines; diff suppressed because it is too large)
@@ -31,13 +31,14 @@

namespace tensorrt_llm::executor::kv_cache
{

struct TargetRanksInfo
{
int mDomainPPSize;
int mDomainTPSize;
std::vector<int> mIRanks;
int mDuplicateHeadFactor;
int mPeerDuplicateHeadFactor;
int mDupHeadFactor;
int mPeerDupHeadFactor;
};

TargetRanksInfo targetIRanks(

@@ -46,17 +47,20 @@ TargetRanksInfo targetIRanks(
TargetRanksInfo TargetRanksInfoForDP(
kv_cache::CacheState const& peerCacheState, kv_cache::CacheState const& selfCacheState, int selfRank);

void concatenateKVCacheDispatch(runtime::ITensor::SharedPtr* inputBlocks, int inputBlockNum,
void concatKVCacheDispatch(runtime::ITensor::SharedPtr* inputBlocks, int inputBlockNum,
std::vector<int> const& inputRanks, kv_cache::CacheState const& peerCacheState,
runtime::ITensor::SharedPtr* outputBlocks, int outputBlockNum, int selfRank,
kv_cache::CacheState const& selfCacheState, runtime::BufferManager const& bufferManager);

nvinfer1::Dims makeShapeFromCacheState(kv_cache::CacheState const& cacheState);

void splitKVCacheDispatch(std::vector<runtime::ITensor::SharedPtr> const& kVCacheBlocks,
void splitKVCacheDispatch(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> const& kVCacheBlocksPerWindow,
std::vector<runtime::ITensor::SharedPtr>& outputSplitBlocks, kv_cache::CacheState const& peerCacheState,
kv_cache::CacheState const& selfCacheState, int selfIdx, runtime::BufferManager const& bufferManager);

void concatenateKvCacheV2Dispatch(std::vector<runtime::ITensor::SharedPtr> const& inputSplitBlocks,
std::vector<runtime::ITensor::SharedPtr>& outputKvCacheBlocks, kv_cache::CacheState const& peerCacheState,
kv_cache::CacheState const& selfCacheState, int selfIdx, runtime::BufferManager const& bufferManager);
void concatKvCacheV2Dispatch(std::vector<runtime::ITensor::SharedPtr> const& inputSplitBlocksPerWindow,
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>& outputKvCacheBlocksPerWindow,
kv_cache::CacheState const& peerCacheState, kv_cache::CacheState const& selfCacheState, int selfIdx,
runtime::BufferManager const& bufferManager);

} // namespace tensorrt_llm::executor::kv_cache
@@ -132,9 +132,10 @@ std::optional<std::shared_ptr<KVCacheEventManager>> Executor::getKVCacheEventMan
return mImpl->getKVCacheEventManager();
}

KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data)
KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data, SizeType32 windowSize)
: eventId{eventId}
, data{std::move(data)}
, windowSize{windowSize}
{
}
(File diff suppressed because it is too large)
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6173ab315983d8844078fbddd8410ea6b99d30092e5c6dc467fda10300620b74
|
||||
size 601111
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f32d82ae86c521360042b14f1b6a6d79b2bcfe23f6d129af99df591787007dee
|
||||
size 912898
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f7bf690286a3f532c5375cd76db7383ba552a59f60eba114584e5cde0043834a
|
||||
size 1385720
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f73d1f5e15a69c4455a57a351f856f544b097543991c17c0620917d1e1fd3fad
|
||||
size 1456760
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e56cb50ecd9aac19bd3af9b65ec3f0e04aef868596dc625939a0e4ad0693ff13
|
||||
size 1456760
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1aa3a4f9101c656e57a9053f6f669f36d897e97d29d5c0889b0fa74478a315da
|
||||
size 1979300
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1ae2f8df40a25cb8b09f6ce2fb838953e8bbab1ad6fb71a372739d9a8a6636ff
|
||||
size 1389654
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c93bb4f2f953d9f0d46139642a87a9955c338cf00d757d95c91d02cf0671e329
|
||||
size 1409386
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:087062c343a9d04afda590db19761e37a7ad53740f4a1919e86dc439d86e9d37
|
||||
size 1409386
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9d0e082555cbda07638de0d1d838269437f7100e6f12afd98c3a3dc378d2aa7c
|
||||
size 1948502
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5c46353a6c00c154ed5d7bbb52c56b42f8dccf5a700f928243029ccfafee3013
|
||||
size 308265
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f4f0d5736d6801f3614c72f31581c1e227cf51eafb60e009b47f267982f36136
size 292477
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9d1c4f9a5c53d3f226dda0c2f1dd53afac4f3719731130af6a9ce704e9b55d0e
size 515083
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8662ebc259db8989f193c69e1aea9bc2de7da97d8f0564ca023d77123cfc05d8
size 679266
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33c76fd50a8a68c154e3c5016767f1deef66b9b369885fce6fe5da1ecabe83b5
size 742412
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69eef116cc9ceeb142af8d83bf9463fd1678539ac11915712be7b7123f71aed8
size 782692
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80da78fcf36253cfa63bc5cd7891cf4f79ed32ade50c3bf4c6ab209abb77cf46
size 780300
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:798951dbc53219e7402642bd6b49a5eeb01010ff76a0ab8ae99f519effc86080
size 980002
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69aef72f514c7338449e301205aca1a411ed466f90801410547d241f2147f339
size 507977
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:737387664ae52b4874af7971c93f70942f17a559dd68dac553b59be682183d60
size 507977
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:23785e6d85c7a93d7a0f8691d79a6de1c953fbb4ee057cb8ac13a10c0b1ed6d6
size 517449
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ffefd85f6395becfe5b80d863761617fea35167138b738d924718efcb1736f49
size 499283
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:346b1557eee6957ed0cf3b793c86b78dbcaa799bc806798f15c28eaf6581e110
size 184391
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fec694c26cdda7b808b836a7b18918b56eca406c0d42108cec6c60c31d882209
size 184391
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:039256731f20528aab02a8df3729680d8cc9c9bb03b89047724b58c185d65f74
size 665832
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4bad8fa30b04f0f3a13edc310a6b9eb6e99ca31cad75a15410e233327babdbd
size 674516
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:001374158c745bc46dec1996a7d1ba0a3b537c8c354ecd6938e5ef9d93339bcc
size 725056
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bd5818a16a40b85edb46f08b23b78adcaf3dac0defcc86000fcf0589a6874f1
size 722664
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ed8dbc734d33ec27051eac487109d50ef8c63edb6471b4f8b0fd403d807bc173
size 932628
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b22e753cfbcf3314884fc4557c973d6cf2486cef891f0ed74a680a3e34ffac20
size 638204
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8797953ca8e515e35a955de5e7a173dd2f83be3c807844fb4c4f04128c4840b8
size 161497
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:65cf71ff8b657165ff727d1bd90266042fcf1c31e0882953415d9f66e14b8eb3
size 161497
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bc72689d27d04bbff63953c8772069ffde934aac9017fb22be9b27f056fa826d
size 488229
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:960e14c1154c414028e1eec2b88258cd5d6d4db05ad0905836eb59527f0bc7dc
size 500859
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:30f39bd5e745d016a62d93b5bff3b86eba92b91a8391579dac8e9ff3f43b4c89
size 232533
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a7c5b8d27d0e3470bf7a5600722e8c9cb977802746ce529b9224b2aaf197c40
size 231721
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:beb4939e0f07e964f53db3bc7f051e124a89d684caacbf53b4d882049c979541
size 287763
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66dcf4cefafc80111d5c517466d3be1b96fdef31975a7fbd0afbe903b90e8694
size 231731
@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:341f1667912db3b3cb2f5b98e41c9f41d5458e47c3d0cfd056a4191a81f550ae
size 230917
@ -83,8 +83,8 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
FusedMHARunnerV2::FusedMHARunnerV2(MHARunnerFixedParams fixedParams)
: mFixedParams(fixedParams)
{
TLLM_CHECK_WITH_INFO(
(mSM == kSM_80 || mSM == kSM_86 || mSM == kSM_89 || mSM == kSM_90 || mSM == kSM_100 || mSM == kSM_120),
TLLM_CHECK_WITH_INFO((mSM == kSM_80 || mSM == kSM_86 || mSM == kSM_89 || mSM == kSM_90 || mSM == kSM_100
|| mSM == kSM_120 || mSM == kSM_121),
"Unsupported architecture");
TLLM_CHECK_WITH_INFO((mFixedParams.dataType == DATA_TYPE_FP16 || mFixedParams.dataType == DATA_TYPE_BF16
|| mFixedParams.dataType == DATA_TYPE_E4M3),
@ -305,7 +305,7 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
bool const isSm80 = (mSM == kSM_80);
bool const isSm89 = (mSM == kSM_89);
bool const isSm100 = (mSM == kSM_100);
bool const isSm120 = (mSM == kSM_120);
bool const isSm120f = (mSM == kSM_120 || mSM == kSM_121);

// Sliding_or_chunked_causal mask.
if ((runnerParams.kvSeqLen > runnerParams.slidingWindowSize
@ -356,7 +356,7 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
mLaunchParams.kernel_s = 0;
mLaunchParams.force_unroll = true;
// enable tiled kernels on Ampere/Ada
if ((isSm89 || isSm120) && mFixedParams.dataType == DATA_TYPE_E4M3)
if ((isSm89 || isSm120f) && mFixedParams.dataType == DATA_TYPE_E4M3)
{
// so far Ada QMMA only supports non-tiled kernels.
mLaunchParams.granular_tiling = false;
@ -368,12 +368,12 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
// can suffer from tile quantization loss therefore use flash attention non-tiled instead
mLaunchParams.granular_tiling = false;
}
else if ((isSm8x || isSm120) && mFixedParams.headSize < 256)
else if ((isSm8x || isSm120f) && mFixedParams.headSize < 256)
{
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
mLaunchParams.granular_tiling = false;
}
else if (isSm80 || isSm8x || isSm100 || isSm120)
else if (isSm80 || isSm8x || isSm100 || isSm120f)
{
// otherwise, choose tiled kernel for Ampere/Ada/Gb20x
mLaunchParams.granular_tiling = true;

@ -552,6 +552,10 @@ uint64_t FusedMultiHeadAttentionXMMAKernelV2::hashFromParams(

FusedMultiHeadAttentionXMMAKernelV2 const* getXMMAKernelsV2(Data_type inputType, Data_type outputType, unsigned int sm)
{
if (sm == kSM_121)
{
sm = kSM_120;
}
return FusedMHAKernelFactoryV2::Get().getXMMAKernels(sMhaKernelMetaInfosV2,
sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), inputType, outputType, sm);
}
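For orientation, a minimal host-side sketch (an assumption, not the repository's API) of the pattern the change above applies: SM 121 is treated as part of the SM 120 family, so feature checks use a family flag while kernel lookup aliases SM 121 to the SM 120 cubins. The constant and helper names below mirror the diff but are hypothetical.

#include <cstdint>

namespace sketch
{
constexpr uint32_t kSM_120 = 120;
constexpr uint32_t kSM_121 = 121;

// True for every member of the SM 120 family (mirrors the isSm120f flag above).
constexpr bool isSm120Family(uint32_t sm)
{
    return sm == kSM_120 || sm == kSM_121;
}

// SM value used for kernel lookup: SM 121 reuses the SM 120 kernel set.
constexpr uint32_t kernelLookupSm(uint32_t sm)
{
    return sm == kSM_121 ? kSM_120 : sm;
}
} // namespace sketch
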
@ -70,8 +70,10 @@ function(process_target target_name enable_hopper enable_blackwell)
PUBLIC COMPILE_HOPPER_TMA_GROUPED_GEMMS)
endif()

if(${enable_blackwell} AND ("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG))
if(${enable_blackwell}
AND ("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG))

if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
# No kernels should be parsed, unless blackwell is specified. This is a
@ -81,7 +83,8 @@ function(process_target target_name enable_hopper enable_blackwell)
target_compile_definitions(${target_name}
PUBLIC COMPILE_BLACKWELL_TMA_GROUPED_GEMMS)
endif()
if("120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
if("120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(${target_name}
PRIVATE COMPILE_BLACKWELL_SM120_TMA_GEMMS)
target_compile_definitions(
@ -116,7 +119,7 @@ function(add_instantiations library base_dir)

if(${ARCH} EQUAL 90)
process_target(${TARGET_NAME} true false)
elseif(${ARCH} EQUAL 100)
elseif(${ARCH} GREATER_EQUAL 100)
process_target(${TARGET_NAME} false true)
endif()
endif()
@ -246,11 +249,14 @@ endif()
if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
add_library(
ar_gemm_src STATIC
${ARGEMM_SRC_CU}
${CMAKE_CURRENT_SOURCE_DIR}/../../runtime/ipcNvlsMemory.cpp)
${ARGEMM_SRC_CU} ${CMAKE_CURRENT_SOURCE_DIR}/../../runtime/ipcNvlsMemory.cu)
target_include_directories(
ar_gemm_src
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../internal_cutlass_kernels/include)
if(ENABLE_NVSHMEM)
target_link_libraries(ar_gemm_src PRIVATE nvshmem::nvshmem_host
nvshmem::nvshmem_device)
endif()
set_cuda_architectures(ar_gemm_src 90 100f)
endif()

@ -138,7 +138,7 @@ public:
// Epilogue
////////////////
using FusionCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true, true>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true>;
using EpilogueScheduleType = typename MmaAdapter<MmaType, IsFP4>::EpilogueSchedule;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOp

@ -100,8 +100,7 @@ public:
using RasterOrderOptions =
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;

using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true /* Safe across phases */,
true /* membar.gpu */>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true /* Safe across phases */>;

// 16B alignment for TMA
static constexpr int AlignmentA = 16 / sizeof(ElementA);

@ -201,7 +201,7 @@ public:
auto [M, N, K, L] = problem_shape;
auto [m, n, k, l] = tile_coord;

if (!tile_valid(m, n) || params_ptr->world_size == 1)
if (!tile_valid(m, n) || params_ptr->world_size <= 2)
{
return; // nothing to do
}
@ -212,7 +212,7 @@ public:

// Wait for all multicast writes to be visible to us.
// This is safe between phases.
SystemBarrier::arrive_and_wait(
SystemBarrier::arrive_and_wait<cuda::thread_scope::thread_scope_system>(
params_ptr->barrier_params_final_sync, thread_idx, tile_index, params_ptr->rank, params_ptr->world_size);
}

@ -297,13 +297,20 @@ public:
Tensor tGR_gD1_vec = zipped_divide(tGR_gD1(_, _, _, red_m, red_n), Vec);
Tensor tRG_gOut_vec = zipped_divide(tRG_gOut(_, _, _, red_m, red_n), Vec);

auto pred_fn
= [&](auto const&... coords) { return elem_less(tGR_pD_vec(_0{}, coords...), problem_shape); };
// Create predicate tensor for bounds checking
Tensor pred_tensor = make_tensor<bool>(make_shape(size(tGR_pD_vec)), Stride<_1>{});

// Set predicate values based on coordinate bounds
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred_tensor); ++i)
{
pred_tensor(i) = elem_less(tGR_pD_vec(_0{}, i), problem_shape);
}

// Read from self.
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD0_vec, tGR_rD0_vec);
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD0_vec, tGR_rD0_vec);
// Read from remote.
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD1_vec, tGR_rD1_vec);
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD1_vec, tGR_rD1_vec);
// Reduce
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tGR_rD0_vec); i++)
@ -311,7 +318,7 @@ public:
tGR_rD0_vec(i) += tGR_rD1_vec(i);
}
// store to self.
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_rD0_vec, tRG_gOut_vec);
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_rD0_vec, tRG_gOut_vec);
}
}
}
@ -386,13 +393,21 @@ public:
Tensor tGR_gD_vec = zipped_divide(tGR_gD(_, _, _, red_m, red_n), Vec);
Tensor tRG_gD_vec = zipped_divide(tRG_gD(_, _, _, red_m, red_n), Vec);
Tensor tGR_pD_vec = zipped_divide(tGR_pD(_, _, _, red_m, red_n), Vec);
// problem shape bounds check
auto pred_fn
= [&](auto const&... coords) { return elem_less(tGR_pD_vec(_0{}, coords...), problem_shape); };

// Create predicate tensor for bounds checking
Tensor pred_tensor = make_tensor<bool>(make_shape(size(tGR_gD_vec)), Stride<_1>{});

// Set predicate values based on coordinate bounds
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred_tensor); ++i)
{
pred_tensor(i) = elem_less(tGR_pD_vec(_0{}, i), problem_shape);
}

// load-reduce in switch
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD_vec, tGR_rD_vec);
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD_vec, tGR_rD_vec);
// store switch multicast
cute::copy_if(CopyAtomR2G{}, pred_fn, tGR_rD_vec, tRG_gD_vec);
cute::copy_if(CopyAtomR2G{}, pred_tensor, tGR_rD_vec, tRG_gD_vec);
}
}
}

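The two hunks above replace a lambda predicate with a materialised predicate tensor for cute::copy_if. As a rough, library-free analogue (an assumption, not CuTe code): the bounds check is evaluated once into a boolean mask, and the same mask then guards every copy (local read, remote read, store).

#include <array>
#include <cstddef>

// Build the mask once from a bounds predicate.
template <std::size_t N, typename Pred>
std::array<bool, N> makeMask(Pred inBounds)
{
    std::array<bool, N> mask{};
    for (std::size_t i = 0; i < N; ++i)
    {
        mask[i] = inBounds(i);
    }
    return mask;
}

// Guarded element-wise copy, reused for each source/destination pair.
template <typename T, std::size_t N>
void copyIf(std::array<bool, N> const& mask, std::array<T, N> const& src, std::array<T, N>& dst)
{
    for (std::size_t i = 0; i < N; ++i)
    {
        if (mask[i])
        {
            dst[i] = src[i];
        }
    }
}
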
@ -171,7 +171,7 @@ struct Sm100AllReduceArrive
tma_store_wait<0>();

int tile_idx = params_ptr->tile_layout(m, n);
SystemBarrier::arrive_inc(
SystemBarrier::arrive_inc<cuda::thread_scope::thread_scope_device>(
params_ptr->barrier_params, thread_idx, tile_idx, params_ptr->rank, params_ptr->world_size);
}
}

@ -268,7 +268,7 @@ struct Sm90AuxAllReduce
tma_store_wait<0>();

int tile_idx = params_ptr->tile_layout(m, n);
SystemBarrier::arrive_inc(
SystemBarrier::arrive_inc<cuda::thread_scope::thread_scope_device>(
params_ptr->barrier_params, thread_idx, tile_idx, params_ptr->rank, params_ptr->world_size);
}
};

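The arrive calls above now spell out their thread scope instead of relying on a default. A hedged CUDA sketch of the underlying idea (libcu++ atomics, not the kernel's SystemBarrier API; assumes a toolkit that ships cuda::atomic_ref): a counter polled only by the local GPU can use device scope, while one observed by peer GPUs or the host needs system scope.

#include <cuda/atomic>

__global__ void arriveSketch(int* deviceVisibleCounter, int* systemVisibleCounter)
{
    if (threadIdx.x == 0 && blockIdx.x == 0)
    {
        // Increment fenced for this device only (cheaper).
        cuda::atomic_ref<int, cuda::thread_scope_device> dev(*deviceVisibleCounter);
        dev.fetch_add(1);

        // Increment made visible to other devices and the CPU.
        cuda::atomic_ref<int, cuda::thread_scope_system> sys(*systemVisibleCounter);
        sys.fetch_add(1);
    }
}
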
Some files were not shown because too many files have changed in this diff.