Compare commits

...

382 Commits

Author SHA1 Message Date
benzh-2025
6df2c8a074
[None][feat] add fp4 gemm + allreduce (#9729)
Signed-off-by: benzh 
Signed-off-by: benzh-2025
2026-01-13 21:11:13 +08:00
Guoming Zhang
c1b0b7350f
[None][test] Unwaive qwen3 next test case. (#9877)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
2026-01-13 20:42:31 +08:00
Tailing Yuan
38296a472b
[None][feat] Layer-wise benchmarks: make model init more general and support weights loading (#10562)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
2026-01-13 19:17:03 +08:00
mpikulski
50c78179dd
[TRTLLM-8425][doc] document Torch Sampler details (#10606)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
2026-01-13 12:01:20 +01:00
Erin
55580f8ec1
[NVBUG-5670458][chore] Unwaive lp tests (#10524)
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Signed-off-by: Erin <14718778+hchings@users.noreply.github.com>
2026-01-13 04:31:27 -05:00
Void
7d16f3a28b
[https://nvbugs/5788127][fix] Use uint64_t as the dtype of lamport_buffer_size to avoid overflow (#10499)
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
2026-01-13 17:16:22 +08:00
Guoming Zhang
bdaee87895
[TRTLLM-10060][feat] Enable attention dp for Nemotron Super v3. (#10347)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
2026-01-13 17:13:55 +08:00
JunyiXu-nv
e291a834db
[TRTLLM-8462][feat] Support GET/DELETE v1/responses/{response_id} (#9937)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
2026-01-13 03:57:14 -05:00
Yuxian Qiu
04b112651b
[None][feat] Hang detection for executor loop and worker. (#10480)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2026-01-13 02:34:32 -05:00
Yiteng Niu
50c22b80d7
[None][infra] Update allowlist 2026.01.08 (#10535)
Signed-off-by: Yiteng Niu <6831097+niukuo@users.noreply.github.com>
2026-01-13 15:28:53 +08:00
tburt-nv
7d41475954
[None][infra] try removing shared cache dir mount (#10609)
Signed-off-by: Tyler Burt <195370667+tburt-nv@users.noreply.github.com>
2026-01-13 15:07:12 +08:00
JennyLiu
2967d299fb
[TRTLLM-10271][test] Add Spark QA functional and performance cases (#10564)
Signed-off-by: Jenny Liu <JennyLiu-nv+JennyLiu@users.noreply.github.com>
Co-authored-by: Jenny Liu <JennyLiu-nv+JennyLiu@users.noreply.github.com>
2026-01-13 13:20:15 +08:00
TensorRT LLM
ba1cb6831d [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-13 03:08:08 +00:00
fredricz-20070104
bbe535fddf
[None][chore] Fix disagg assert (#10596)
Signed-off-by: FredricZ-2007 <226039983+fredricz-20070104@users.noreply.github.com>
2026-01-12 21:39:57 -05:00
xxi
ba1037ca4a
[https://nvbugs/5762336][fix] support parsing the keyword modules_to_not_convert of the HF model config (#10527)
Signed-off-by: xxi <xxi@nvidia.com>
2026-01-12 20:21:01 -05:00
Iman Tabrizian
48b09e5a25
[https://nvbugs/5689235][fix] Fix cancellation+chunked prefill+disagg (#10111)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
2026-01-12 18:23:26 -05:00
Gal Hubara-Agam
18a33764b5
[None][chore] Print correct backend name in benchmark report (#10597)
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
2026-01-12 14:46:00 -05:00
Anish Shanbhag
dacc881993
[https://nvbugs/5761391][fix] Use correct model names for config database regression tests (#10192)
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
2026-01-12 10:55:07 -08:00
Suyog Gupta
a1385243e1
[#10580][fix] re-enable NemotronH MOE MMLU test (#10594)
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
2026-01-12 09:26:07 -08:00
Emma Qiao
9f044b9dd9
[None][infra] Waive failed tests for main 01/12 (#10604)
Signed-off-by: qqiao <qqiao@nvidia.com>
2026-01-12 10:24:54 -05:00
mpikulski
bf7998f1b8
[TRTLLM-9522][test] cover LLM API multi_modal_embeddings (#9963)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
2026-01-12 11:38:22 +01:00
Wanli Jiang
11da7e3605
[None][fix] Solve pillow version conflict (#10537)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
2026-01-12 04:05:54 -05:00
Zhenhuan Chen
3bd319dc8e
[https://nvbugs/5794796][chore] waive test blocking premerge (#10593)
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
2026-01-12 15:39:07 +08:00
yufeiwu-nv
8e806abac3
[None][test] Remove most TRT-backend test cases in llm_perf_nim.yml (#10572)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2026-01-12 15:34:55 +08:00
yingguo-trt
c5914f9085
[None][chore] update deepseekv3.2 test parameter (#10595)
Signed-off-by: yingguo-trt <244492186+yingguo-trt@users.noreply.github.com>
2026-01-12 01:43:22 -05:00
chenfeiz0326
54459377d2
[TRTLLM-10248][feat] Support Bot to Send Perf Regression Msg to Slack Channel (#10489)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
2026-01-12 14:23:23 +08:00
Xianjie Qiao
3a9a00b544
[None][feat] Add ExpertStatistic and DUMMY_ALLREDUCE for configurable_moe (#10401)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
2026-01-12 14:10:31 +08:00
Jie Li
5e0dbba0c9
[None][chore]: update waive list (#10577)
Signed-off-by: Jie Li <lijie@nvidia.com>
2026-01-11 22:18:04 -05:00
TensorRT LLM
2de22f1a70 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-12 03:09:53 +00:00
Pengbo Wang
c0e25e5418
[TRTLLM-10022][feat] Add hopper xqa decode support for skip softmax attention (#10264)
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
2026-01-11 19:26:10 -05:00
Eran Geva
c5d5af9e7f
[#8391][chore] removed llama and added deepseek to AutoDeploy's L0 perf test (#10585)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
2026-01-11 16:31:24 -05:00
Ivy Zhang
7f018c89e9
[None][test] update core test list (#10538)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
2026-01-11 14:08:20 -05:00
Yechan Kim
8e0d20d901
[TRTLLM-10195][feat] K-EXAONE support (#10355)
Signed-off-by: Jaedeok Kim <jaedeokk@nvidia.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Co-authored-by: Jaedeok Kim <jaedeokk@nvidia.com>
2026-01-12 00:29:51 +09:00
Yanchao Lu
80649a8b78
[None][ci] Workaround OCI-NRT slowdown issue (#10587)
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
2026-01-11 22:08:19 +08:00
Guoming Zhang
0371cbfd88
[None][doc] Update Qwen3-Next doc by adding known issues section (#10582)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
2026-01-11 14:47:47 +08:00
TensorRT LLM
b2e2538fcd [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-11 03:07:48 +00:00
HuiGao-NV
3c65ec3c55
[None][chore] waive test case (#10581)
Signed-off-by: Hui Gao <huig@nvidia.com>
2026-01-10 18:53:36 -05:00
fredricz-20070104
f6045fac09
[None][chore] Fix Gitlab CI termination issues (#10576)
Signed-off-by: FredricZ-2007 <226039983+fredricz-20070104@users.noreply.github.com>
Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
Co-authored-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
2026-01-10 07:51:18 -05:00
tcherckez-nvidia
f6c4dd885f
[None][chore] Update AutoDeploy model list (#10505)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
2026-01-10 08:47:37 +02:00
TensorRT LLM
6ab996d635 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-10 03:09:09 +00:00
William Zhang
ff7eb93f31
[https://nvbugs/5669097][tests] Add MMMU test for mistral small (#10530)
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2026-01-09 16:09:28 -08:00
Chenghao Zhang
38f249b479
[https://nvbugs/5548861][fix] AutoDeploy: Fix the test (#10521)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
2026-01-09 13:30:24 -08:00
Linda
82dfef2e56
[https://nvbugs/5628848][fix] Fix nanobind stub generation (#10516)
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
2026-01-09 11:32:21 -08:00
Faraz
fdbdbba540
[https://nvbugs/5752687][fix] Choose register model config over root config for VLM (#10553)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
2026-01-09 12:10:52 -05:00
yingguo-trt
d80f01d205
[None][feat] Add support for DeepSeek v3.2 tests (#10561)
Signed-off-by: yingguo-trt <244492186+yingguo-trt@users.noreply.github.com>
2026-01-09 10:20:29 -05:00
Yechan Kim
7295af68ba
[None][fix] Enable AttentionDP on Qwen3-VL and fix test (#10435)
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
2026-01-10 00:13:26 +09:00
Kaiyu Xie
1c69aad850
[TRTLLM-10309] [feat] Optimize qk rope/nope concat for DSA (#10571)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2026-01-09 09:50:57 -05:00
Iman Tabrizian
ced88424ef
[https://nvbugs/5756008][fix] unwaive test (#10523)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
2026-01-09 09:40:07 -05:00
Jie Li
627d306df9
[None][chore] remove some model support; add device constraint (#10563)
Signed-off-by: Jie Li <lijie@nvidia.com>
2026-01-09 09:36:23 -05:00
ruodil
2b72d33fdc
[TRTLLM-9932][test] add kimi_k2 single node perf test (#10436)
Signed-off-by: Ruodi Lu <ruodil@users.noreply.github.com>
Co-authored-by: Ruodi Lu <ruodil@users.noreply.github.com>
2026-01-09 05:36:50 -05:00
Fanrong Li
4632a8642d
[None][doc] blog: Optimizing DeepSeek-V3.2 on NVIDIA Blackwell GPUs (#10565)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2026-01-09 05:16:00 -05:00
Yuxian Qiu
80f261ea36
[https://nvbugs/5622938][feat] Run sample_async on extra stream. (#10215)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2026-01-09 18:15:18 +08:00
Chang Liu
78bb245554
[https://nvbugs/5787453][fix] Better align MLA chunking with indexer chunking when chunked prefill enabled for DSV32 (#10552)
2026-01-09 00:49:39 -08:00
bhsueh_NV
4a09acd012
[https://nvbugs/5785206][infra] unwaive the accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B (#10560)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
2026-01-09 03:13:29 -05:00
JadoTu
4c498bfe58
[TRTLLM-9676][fix] Fix mamba_cache_manager when enabling cuda_graph_padding and let test cover this case (#9873)
Signed-off-by: JadoTu <107457950+JadoTu@users.noreply.github.com>
2026-01-09 14:50:16 +08:00
Yukun He
c5331e6dbb
[None][fix] Setup dist for AutoTuner in Layerwise benchmarking. (#10534)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2026-01-09 14:16:39 +08:00
Jie Li
6fcd4e7099
[None][chore] Add failed cases into waives.txt (#10541)
Signed-off-by: Jie Li <lijie@nvidia.com>
2026-01-09 01:03:47 -05:00
TensorRT LLM
5df03b2ea7 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-09 03:43:08 +00:00
ruodil
d707286ca8
[None][test] restrict max_num_tokens in disagg mtp config (#10442)
Signed-off-by: Ruodi Lu <ruodil@users.noreply.github.com>
Co-authored-by: Ruodi Lu <ruodil@users.noreply.github.com>
2026-01-08 21:53:24 -05:00
Yuxian Qiu
afa55c12b6
[None][fix] revert https://github.com/NVIDIA/TensorRT-LLM/pull/10445. (#10547)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2026-01-08 21:50:04 -05:00
Balaram Buddharaju
56e779d09f
[None][chore] Waive tests blocking premerge 01/08 (#10555)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2026-01-08 20:22:28 -05:00
Mike Iovine
4092a87b6f
[https://nvbugs/5740075][fix] Fix sm120 speculation (#10049)
Signed-off-by: Mike Iovine <miovine@nvidia.com>
2026-01-08 19:55:43 -05:00
Eran Geva
489dd60312
[#10513][fix] AutoDeploy: removed self.mlp_type leftovers from last moe refactor (#10512)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
2026-01-08 14:49:40 -05:00
mpikulski
e0331297a6
[TRTLLM-9522][fix] broken cast (#9975)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
2026-01-08 06:47:39 -05:00
William Zhang
c0ae6bbdbe
[None][feat] EPD for Qwen3 VL (#10470)
* Why?

We would like to support EPD disaggregated serving for Qwen3 VL.

* What?

This commit adds such support, and extends existing unit tests for
correctness checks.

Some minor (protected) interface changes had to be made to the
weight mapper as a side-effect.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2026-01-08 06:45:54 -05:00
Eran Geva
6511dbaea0
[#10417][fix] AutoDeploy - Reverted to direct computation of minusA (#10509)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
2026-01-08 13:43:41 +02:00
bhsueh_NV
bea61bb17d
[None][fix] Mistral large 3 few code refine (#10405)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
2026-01-08 06:38:49 -05:00
Yiqing Yan
dc6b743fb6
[None][chore] Bump version to 1.2.0rc8 (#10542)
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
2026-01-08 04:51:44 -05:00
Emma Qiao
43839c7d9b
[TRTLLM-9642][infra] Increase pytest verbosity for failed tests (#9657)
Signed-off-by: qqiao <qqiao@nvidia.com>
Signed-off-by: Emma Qiao <qqiao@nvidia.com>
2026-01-08 02:33:48 -05:00
dongfengy
8d4b09dac6
[None][doc] Update GPTOSS Doc (#10536)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
2026-01-08 02:30:53 -05:00
HuiGao-NV
22c81cb5fa
[None][chore] Enable seg fault cases since one race condition is fixed (#10398)
Signed-off-by: Hui Gao <huig@nvidia.com>
2026-01-08 02:15:30 -05:00
Barry Kang
f57aab5255
[https://nvbugs/5775402][fix] Fix concurrency list in Wide-EP perf tests (#10529)
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
2026-01-08 01:58:55 -05:00
Lucas Liebenwein
30f8455d29
[https://nvbugs/5747878][fix] unwaive llama4 scout tests (#10468)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
2026-01-07 23:33:45 -05:00
TensorRT LLM
342a47bf47 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-08 03:12:45 +00:00
yingguo-trt
f8b2a8fd30
[None][chore] Support multiple job submission at the same time (#10492)
Signed-off-by: FredricZ-2007 <226039983+fredricz-20070104@users.noreply.github.com>
Co-authored-by: FredricZ-2007 <226039983+fredricz-20070104@users.noreply.github.com>
2026-01-07 21:51:36 -05:00
Yuxian Qiu
b85c447ceb
[https://nvbugs/5784543][fix] Setup dist before using autotuner. (#10491)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2026-01-08 10:32:50 +08:00
Yukun He
09d9878385
[TRTLLM-9661][chore] Further reduce tuning time for cuteDSL nvFP4 dense gemm. (#10339)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2026-01-08 10:21:02 +08:00
xxi
81f878c279
[https://nvbugs/5707392][fix] unwaive test_fused_moe_fp8_blockwise_wide_ep[NotEnabled] (#10428)
Signed-off-by: xxi <xxi@nvidia.com>
2026-01-08 09:17:59 +08:00
Lucas Liebenwein
d736c7f290
[https://nvbugs/5761665][fix] AutoDeploy: handle bugs for 25.12 dlfw upgrade (#10511)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
2026-01-07 20:16:53 -05:00
Ziyi Xiong
7187afe7b9
[https://nvbugs/5781589][fix] Skip spec dec for non-last rank (#10445)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
2026-01-07 13:55:45 -05:00
Patrice Castonguay
e8cceb06b2
[None][doc] Adding parallelism types in feature combination matrix (#9849)
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
2026-01-07 12:52:05 -05:00
yufeiwu-nv
b130d58c88
[None][test] Remove most TRT-backend test cases in llm_perf_nim.yml (#10487)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2026-01-07 17:18:43 +08:00
tcherckez-nvidia
7e88212d24
[None][bug] fix export for microsoft/Phi-3-medium-128k-instruct (#10455)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
2026-01-07 10:30:24 +02:00
xinhe-nv
872210468b
[TRTLLM-8638][fix] Add failed cases into waives.txt (#10474)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2026-01-07 03:23:43 -05:00
Kanghwan
dc32bac9fc
[#4745][fix] Pass lora_params through Qwen2/3 model forward (#10174)
Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
2026-01-07 15:30:17 +08:00
yingguo-trt
cbf8357e5f
[https://nvbugs/5726086][fix] update kimi-k2-1k1k dataset (#10473)
Signed-off-by: yingguo-trt <244492186+yingguo-trt@users.noreply.github.com>
2026-01-07 01:24:08 -05:00
xinhe-nv
be5579633e
[TRTLLM-8638][fix] Add failed cases into waives.txt (#10457)
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
2026-01-07 00:57:03 -05:00
Fanrong Li
a34aa63685
[https://nvbugs/5767223][feat] add pp support for DeepSeek-v3.2 (#10449)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2026-01-07 12:29:51 +08:00
TensorRT LLM
3fec7e411c [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-07 03:10:22 +00:00
xinhe-nv
1fbadd2dde
[None][chore] Add failed cases into waives.txt (#10365)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Jie Li <lijie@nvidia.com>
Signed-off-by: Jie Li <76780849+jieli-matrix@users.noreply.github.com>
Co-authored-by: Jie Li <lijie@nvidia.com>
Co-authored-by: Jie Li <76780849+jieli-matrix@users.noreply.github.com>
2026-01-06 22:08:06 -05:00
Ivy Zhang
4a1b2e23b3
[https://nvbugs/5698434][test] add qwen3-4b accuracy test case (#10382)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
2026-01-06 21:56:34 -05:00
Lucas Liebenwein
6095c80e56
[https://nvbugs/5721907][fix] AutoDeploy: improve numerical stability of flashinfer attention test (#10467)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
2026-01-06 21:11:06 -05:00
Zongfei Jing
bb2f883296
[None] [feat] Add test script and raster M for gather fc1 kernel (#10429)
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
2026-01-07 09:31:49 +08:00
Lucas Liebenwein
bb6a3973aa
[https://nvbugs/5732942][fix] AutoDeploy: handle transformers 4.57.1 upgrade fixes (#10466)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
2026-01-06 19:55:49 -05:00
Lucas Liebenwein
00355b24b7
[None][feat] precompiled installation from local src dir with fnmatch only (#10430)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
2026-01-06 15:31:59 -05:00
Mike Iovine
77be1b7572
[https://nvbugs/5749988][fix] Remove redundant qwen3 spec dec test (#10387)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
2026-01-06 11:46:34 -05:00
Enwei Zhu
037753f65b
[https://nvbugs/5748600][ci] Unwaive disagg guided decoding test (#10409)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2026-01-06 11:38:12 -05:00
Lizhi Zhou
6a4bebcd01
[None][chore] remove redundant retries while binding to arbitrary port (#10452)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
2026-01-06 10:39:15 -05:00
JunyiXu-nv
7d62773c6c
[https://nvbugs/5760726][fix] Use random port in container port section (#10432)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
2026-01-06 23:25:46 +08:00
xinhe-nv
704f58dfbe
[TRTLLM-8638][fix] Add failed cases into waives.txt (#10427)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2026-01-06 04:47:54 -05:00
Emma Qiao
6507087c3f
[None][infra] Waive failed cases on 1/6 (#10440)
Signed-off-by: qqiao <qqiao@nvidia.com>
2026-01-06 16:54:54 +08:00
Bo Li
df0b976b99
[https://nvbugs/5785206][infra] Waive TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]. (#10441)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
2026-01-06 03:32:19 -05:00
William Zhang
ab58d7cac1
[https://nvbugs/5772361][ci] Unwaive tests that have been fixed (#10424)
These tests were all failing due to the same issue, and were fixed
in #10394.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2026-01-05 23:49:54 -08:00
Kaiyu Xie
2eaabd7461
[None] [fix] Fix undefined tokens_per_block (#10438)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2026-01-06 02:42:37 -05:00
Ivy Zhang
1e828587e5
[TRTLLM-9896][test] add vswa test cases coverage (#10146)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
2026-01-06 02:02:29 -05:00
Yiqing Yan
5108a69fc0
[TRTLLM-9622][infra] Enable DGX_B300 multi-gpu testing in pre-merge pipeline (#9699)
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
2026-01-06 14:39:55 +08:00
xinhe-nv
998527724c
[TRTLLM-8638][fix] Add failed cases into waives.txt (#10367)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2026-01-06 01:09:21 -05:00
Kaiyu Xie
810249c304
[https://nvbugs/5769926] [fix] Add no container mount home WAR (#10431)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2026-01-06 13:09:25 +08:00
Ivy Zhang
22a1d31a27
[None][test] update test case constraint (#10381)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
2026-01-06 12:28:59 +08:00
xinhe-nv
1b1058279c
[TRTLLM-8638][fix] Add failed cases into waives.txt (#10384)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2026-01-05 23:02:27 -05:00
kris1025
3e98265682
[None][chore] unwaive qwen3 30b test (#10115)
Signed-off-by: linquanh <linquanh@nvidia.com>
2026-01-06 11:17:08 +08:00
TensorRT LLM
596d4f16fb [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-06 03:16:01 +00:00
Karthik
617f728903
[#8460][feat] Revive and simplify Model Explorer visualization integration (#10150)
Signed-off-by: Karthik Vetrivel <kvetrivel@nvidia.com>
2026-01-05 22:15:25 -05:00
Venky
aa1fe931de
[None][docs] Add --config preference over --extra_llm_api_options in CODING_GUIDELINES.md (#10426)
Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
2026-01-05 22:05:47 -05:00
Xiao Xuan
46f035befe
[#2511][fix] eagle: qwen2 capture hidden states (#10091)
Signed-off-by: SpicyNoodle <522169030@qq.com>
2026-01-05 21:46:41 -05:00
Min Yu
9cae7277ea
[https://nvbugs/5726962][feat] Apply fusion for W4AFP8_AWQ MoE (#9838)
Signed-off-by: Min Yu <171526537+yumin066@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Co-authored-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
2026-01-06 10:16:41 +08:00
alel
6b8ae6fa81
[None][feat] CuteDSL MOE FC1 Enhancement (#10088)
Signed-off-by: Yuhan Li <51736452+liyuhannnnn@users.noreply.github.com>
2026-01-06 09:30:43 +08:00
Mike Iovine
77712ed4ab
[None][chore] Update SWA + spec dec support matrix (#10421)
Signed-off-by: Mike Iovine <miovine@nvidia.com>
2026-01-05 20:26:23 -05:00
JadoTu
82aaf98070
[None][feat] add the eos tokens in generation config to stop words in the sampler (#10389)
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
2026-01-06 09:24:03 +08:00
chenfeiz0326
8a04c05079
[None][fix] Only Use Throughput Metrics to Check Regression (#10404)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
2026-01-06 09:21:15 +08:00
Chuang Zhu
536a8f6a9c
[TRTLLM-9527][feat] Add transferAgent binding (step 1) (#10113)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
2026-01-06 08:40:38 +08:00
Lucas Liebenwein
846e54aa09
[None][feat] precompiled installation from local src dir (#10419)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
2026-01-05 19:16:38 -05:00
Simeng Liu
3b56548fcf
[https://nvbugs/5777044][chore] Remove solved bugs from waives.txt (#10422)
Signed-off-by: Simeng Liu <109828133+SimengLiu-nv@users.noreply.github.com>
2026-01-05 16:56:58 -05:00
Karthik
4e50cb5708
[#10170][fix] Add export patch for GraniteMoe MoE models to enable torch.export compatibility (#10169)
Signed-off-by: Karthik Vetrivel <kvetrivel@nvidia.com>
2026-01-05 16:13:45 -05:00
Mike Iovine
91ff46d418
[https://nvbugs/5745152][fix] Unwaive gpt oss spec decode test (#10370)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
2026-01-05 16:06:58 -05:00
Mike Iovine
7a2dab8e85
[https://nvbugs/5695984][fix] Unwaive llama3 eagle test (#10092)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
2026-01-05 16:03:35 -05:00
Yan Chunwei
6b71b03947
[TRTLLM-9551][infra] Partition test_llm_pytorch.py for parallel execution (#10400)
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
2026-01-05 13:58:03 -05:00
Grzegorz Kwasniewski
ea380ff45c
[TRTLLM-9767][feat] Fixed recursive node traversals (#10379)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
2026-01-05 18:42:06 +02:00
Mike Iovine
db2614ef10
[https://nvbugs/5772414][fix] Fix draft token tree depth=1 corner case (#10385)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
2026-01-05 17:20:14 +01:00
Mike Iovine
bedfff4f00
[https://nvbugs/5772521][fix] Fix draft token tree chain crash (#10386)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
2026-01-05 17:18:44 +01:00
Gal Hubara-Agam
e98c27ee4f
[TRTLLM-10053][feat] AutoDeploy: Add Super v3 config file, improve test runtime (#10397)
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
2026-01-05 18:17:27 +02:00
Anthony Chang
225d3a9001
[None][perf] TRTLLM MoE maps to lower tuning buckets when ep>1 (#9998)
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
2026-01-05 17:16:12 +01:00
Balaram Buddharaju
a792c23dcf
[TRTLLM-9465][fix] Swap TP-CP grouping order (#10350)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2026-01-05 20:08:03 +08:00
Eran Geva
3749a2ce1c
[#10374][fix] fixed race condition in AutoDeploy's mp tests port acquisition (#10366)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
2026-01-05 13:33:01 +02:00
xinhe-nv
b1733d56f6
[TRTLLM-9381][test] add disag-serving kimi k2 thinking tests (#10357)
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
2026-01-05 05:15:52 -05:00
Fanrong Li
4931c5eb3a
[None][feat] update deepgemm to the DeepGEMM/nv_dev branch (#9898)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2026-01-05 16:43:42 +08:00
Yukun He
d272f1a9bc
[TRTLLM-8821][feat] Apply AutoTuner to AllReduce Op for strategy tuning. (#8531)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2026-01-05 15:44:37 +08:00
HuiGao-NV
2f768b76f8
[https://nvbugs/5715568][fix] Force release torch memory when LLM is destroyed (#10314)
Signed-off-by: Hui Gao <huig@nvidia.com>
2026-01-05 15:30:18 +08:00
Emma Qiao
c63fad7d96
[None][infra] Waive failed cases again on 1/5 (#10403)
Signed-off-by: qqiao <qqiao@nvidia.com>
2026-01-05 02:12:16 -05:00
Yihan Wang
e7a4486294
[https://nvbugs/5752521][fix] Unwaive test_trtllm_flashinfer_symbol_collision.py (#10227)
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
2026-01-05 14:37:05 +08:00
Pengyun Lin
c04cf4334e
[TRTLLM-8242][feat] Add stability tags for serve subcommand (#10012)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2026-01-05 14:16:15 +08:00
Yukun He
0937df2c68
[TRTLLM-10185][feat] AutoTuner Cache: Support cache file lock and merge all ranks into one (#10336)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2026-01-05 13:44:09 +08:00
Emma Qiao
5a8bfcbb50
[None][infra] Waive failed cases in post-merge on 1/5 (#10399)
Signed-off-by: qqiao <qqiao@nvidia.com>
2026-01-05 12:30:10 +08:00
Tailing Yuan
a7fe043b13
[None][feat] Layer-wise benchmarks: support TEP balance, polish slurm scripts (#10237)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
2026-01-05 11:23:04 +08:00
TensorRT LLM
aaf80be0f3 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-05 03:19:07 +00:00
Yuxian Qiu
5773a4d775
[https://nvbugs/5701425][chore] Unwaive tests. (#10269)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2026-01-05 09:54:26 +08:00
Cheng Hang
656c705ff1
[None][feat] sm100 weight-only kernel (#10190)
Signed-off-by: Cheng Hang <chang@nvidia.com>
2026-01-05 09:44:36 +08:00
Fanrong Li
b5a1e10bc0
[https://nvbugs/5779534][fix] fix buffer reuse for CUDA graph attention metadata (#10393)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2026-01-05 09:43:44 +08:00
Wanli Jiang
da0830670a
[TRTLLM-10065][feat] Add accuracy tests for super-v3 with multiple-gpus (#10234)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
2026-01-05 09:41:49 +08:00
Lizhi Zhou
82c1ba84a7
[https://nvbugs/5649010][fix] use 0 port as arbitrary port when disagg service discovery is enabled (#10383)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
2026-01-05 09:40:40 +08:00
bhsueh_NV
0517b62789
[https://nvbugs/5772363][fix] fix bug of Mistral-Small-3.1-24B-Instruct-2503 (#10394)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
2026-01-05 09:04:13 +08:00
Faraz
8e2065b4d9
[https://nvbugs/5670469][fix] Filter 0s and choose min of kv_head for Nemotron model (#10206)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
2026-01-05 08:42:53 +08:00
Eran Geva
e2f5455533
[#8391][chore] added deepseek_r1_distill_qwen_32b AutoDeploy perf test to L0 (#10377)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
2026-01-04 20:35:52 +02:00
chenfeiz0326
a65b0d4efa
[None][fix] Decrease Pre Merge Perf Tests (#10390)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
2026-01-04 12:21:34 -05:00
Yanchao Lu
c4f27fa4c0
[None][ci] Some tweaks for the CI pipeline (#10359)
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
2026-01-04 11:10:47 -05:00
dongfengy
afc533193d
[None][feat] Support nvfp4 for gptoss (#8956)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
2026-01-04 08:57:44 -05:00
Jaedeok Kim
a4dcc6a711
[TRTLLM-10171][fix] Correct attention handling in ModelConfig and KVCacheManager (#10330)
Signed-off-by: Jaedeok Kim <jaedeokk@nvidia.com>
2026-01-04 06:07:30 -05:00
Yuxian Qiu
6ba04eba06
[https://nvbugs/5748683][fix] Use get_free_port_in_ci to avoid port conflict. (#10392)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2026-01-04 19:04:58 +08:00
TensorRT LLM
71b4a8aa60 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-04 03:08:01 +00:00
yuanjingx87
5bd37ce41e
[None][infra] add retry logic to get slurm sbatch job log when ssh dropped (#9167)
Signed-off-by: Yuanjing Xue <197832395+yuanjingx87@users.noreply.github.com>
2026-01-04 10:11:37 +08:00
Grzegorz Kwasniewski
0d1f5ad7a2
[TRTLLM-10358][feat] Added proper rescaling of FP4 weights (#10378)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
2026-01-03 16:26:16 -05:00
Yanchao Lu
c0b3c2b919
[None][ci] Remove an invalid test waive
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
2026-01-03 23:34:13 +08:00
Ludwig Schneider
59045a0e41
[None][fix] Make NCCL resource manager destructor exception-safe (#10166)
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
2026-01-03 10:25:05 -05:00
Emma Qiao
865992b86b
[None][infra] Waive failed cases on 1/3 (#10391)
Signed-off-by: qqiao <qqiao@nvidia.com>
2026-01-03 05:54:09 -05:00
Bo Deng
9e7b50aefb
[TRTLLM-9752][fix] WAR: Disable PDL for quant kernels to fix accuracy issues (#10285)
Signed-off-by: Bo Deng <deemod@nvidia.com>
2026-01-03 14:34:55 +08:00
TensorRT LLM
45ffbf1f21 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-03 03:07:50 +00:00
Lucas Liebenwein
937f8f78a1
[None][doc] promote AutoDeploy to beta feature in docs (#10372)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
2026-01-02 18:46:31 -05:00
Izzy Putterman
bdf6953ddc
[None][feat] Eagle: MLA Based Eagle (#9677)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
2026-01-02 13:45:07 -05:00
Gal Hubara-Agam
f3dd6da080
[#10056][chore] AutoDeploy: Enable Nemo SuperV3 accuracy test (#10308)
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
2026-01-02 11:20:19 +02:00
chenfeiz0326
5e0e48144f
[None][fix] Minor updates on Perf Test System (#10375)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
2026-01-02 17:17:42 +08:00
TensorRT LLM
098251648d [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-02 03:11:08 +00:00
fredricz-20070104
f631b25c85
[None][test] Unified slurm extra args management and session collection logic (#10332)
Signed-off-by: FredricZ-2007 <226039983+fredricz-20070104@users.noreply.github.com>
Signed-off-by: yingguo-trt <244492186+yingguo-trt@users.noreply.github.com>
Co-authored-by: yingguo-trt <244492186+yingguo-trt@users.noreply.github.com>
2026-01-01 21:10:51 -05:00
Balaram Buddharaju
4a1b742aa0
[TRTLLM-9467][fix] Fix PP+CP combination with helix parallelism (#10312)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2026-01-01 13:42:53 -05:00
Gal Hubara-Agam
5845951538
[#10056][fix] AutoDeploy: Handle deletion of nested params in sharding (#10376)
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
2026-01-01 08:11:11 -05:00
tcherckez-nvidia
4868772ad7
[None][feat] Add export data to build and run script for AD (#10299)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
2026-01-01 04:54:47 -05:00
Balaram Buddharaju
9f5b750a93
[None][chore] Waive tests blocking pre-merge 12/31 (#10373)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2026-01-01 03:00:24 -05:00
Balaram Buddharaju
0b75340223
[https://nvbugs/5744427][fix] Make Gemma3 multimodal test fp8 (#10368)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2026-01-01 01:11:34 -05:00
TensorRT LLM
edbcff0257 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2026-01-01 03:08:31 +00:00
Yuxian Qiu
ff836d4f41
[https://nvbugs/5740359][chore] Unwaive tests. (#10260)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2026-01-01 09:53:34 +08:00
Lucas Liebenwein
1bbe71b3ed
[#10244][feat] AutoDeploy: separate prefill/decode in flashinfer (#10252)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
2025-12-31 17:01:24 -05:00
Mike Iovine
9085021aa4
[None][feat] Implement sampling for MTP 1-model (#10019)
Signed-off-by: Mike Iovine <miovine@nvidia.com>
2025-12-31 13:48:34 -05:00
Simeng Liu
84d107b2f0
[https://nvbugs/5717993][fix] Add execution_stream across PyExecutor, KVCacheManager, PeftCacheManager to ensure proper CUDA stream synchronization between KV cache transfer operations and model forward kernels. (#10060)
Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
2025-12-31 09:22:54 -08:00
xinhe-nv
0d2e2718ce
[None][chore] Add failed cases into waives.txt (#10354)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2025-12-31 09:30:22 -05:00
chenfeiz0326
a23c6f1092
[TRTLLM-9834][feat] Transfer to TRTLLM-INFRA Database and Fail post-merge tests if regression (#10282)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
2025-12-31 21:44:59 +08:00
tcherckez-nvidia
464847c6be
[#9717][chore] Standardize MoE weights interface (#10295)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
2025-12-31 07:37:18 -05:00
Jin Li
ef1d4a40b5
[https://nvbugs/5727475][fix] Avoid use property with setter in nn.Mo… (#10212)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-12-31 06:21:36 -05:00
Emma Qiao
d944430f96
[None][infra] Waive failed cases on 12/31 (#10353)
Signed-off-by: qqiao <qqiao@nvidia.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-31 17:39:49 +08:00
Necofish
73870ae4ad
[None][feat] support Qwen3-VL dense model in pytorch backend (#9060)
Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
2025-12-31 17:54:26 +09:00
xinhe-nv
827d12caaf
[https://nvbugs/5558516][test] add disaggregated stress test (#9354)
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
2025-12-31 16:47:36 +08:00
Yuxian Qiu
910a633066
[https://nvbugs/5774869][chore] waive tests. (#10356)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-31 03:00:52 -05:00
Yiqing Yan
fdc03684cc
[TRTLLM-10016][infra] Use SlurmPartition attribute time as timeout threshold (#10254)
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-31 15:02:24 +08:00
Pengyun Lin
fad000589d
[None][chore] Unify DS tool parser names (#10239)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2025-12-31 14:40:07 +08:00
xinhe-nv
1e9c153b4c
[None][fix] disable thread leak check for kimi (#10337)
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
2025-12-31 01:31:37 -05:00
xinhe-nv
6c1abf2d45
[None][chore] Add failed cases into waives.txt (#10344)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2025-12-31 00:11:54 -05:00
TensorRT LLM
ed3a3097a4 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-31 03:11:56 +00:00
Jin Li
34c2fd50a9
[https://nvbugs/5707359][fix] Unwaive OOM case that should be fixed by #9446 (#10334)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-12-31 10:41:39 +08:00
Yuxian Qiu
1f3afb8e6f
[None][feat] Implement send_object for TorchDist. (#10213)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-31 10:40:52 +08:00
Yuxian Qiu
ec8a388c25
[https://nvbugs/5769890][fix] Import get_free_port. (#10341)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-31 09:47:27 +08:00
Eran Geva
74832a1895
[https://nvbugs/5766986][fix] fixed the shard_all_unprocessed default value to align with the default.yml (#10271)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
2025-12-30 08:54:13 -05:00
Bo Li
1f0365da36
[None][infra] Add LongBenchV1 to trtllm-eval. (#10265)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
2025-12-30 21:39:34 +08:00
Emma Qiao
6732c76414
[None][infra] Waive failed cases for main on 12/30 (#10338)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-30 05:17:43 -05:00
Emma Qiao
fb05cd769a
[None][infra] Enable single-gpu CI on spark (#9304)
Signed-off-by: qqiao <qqiao@nvidia.com>
Signed-off-by: Emma Qiao <qqiao@nvidia.com>
Signed-off-by: Jenny Liu <JennyLiu-nv+JennyLiu@users.noreply.github.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-30 17:22:14 +08:00
Emma Qiao
cce7247815
[https://nvbugs/5594703][infra] Unwaive the failed case to test (#10275)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-30 16:38:54 +08:00
xinhe-nv
6accdbc6a6
[None][chore] Add failed cases into waives.txt (#10302)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2025-12-30 03:11:52 -05:00
ruodil
0f4ed90560
[TRTLLM-9965][test] add long-context disagg test for GB300/GB200 and remove config_index in yaml (#10225)
Signed-off-by: Ruodi Lu <ruodil@users.noreply.github.com>
Co-authored-by: Ruodi Lu <ruodil@users.noreply.github.com>
2025-12-30 02:39:50 -05:00
binghanc
692d8f2023
[TRTLLM-9455][feat] support for new checkpoint (#10082)
Signed-off-by: binghanc <176802681+binghanc@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-12-30 14:46:39 +08:00
xinhe-nv
3e0344a53d
[None][chore] Add failed cases into waives.txt (#10301)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
2025-12-30 14:04:28 +08:00
xinhe-nv
48fee8d0f6
[None][chore] Add failed cases into waives.txt (#10321)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2025-12-30 00:11:49 -05:00
Emma Qiao
f396ad83b0
[None][infra] Remove duplicates in waives.txt (#10333)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-29 22:32:52 -05:00
TensorRT LLM
fa4c7997c5 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-30 03:07:48 +00:00
Balaram Buddharaju
4944192eae
[None][chore] Waive tests failing in pre-merge 12/28 (#10311)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-29 20:53:49 -05:00
Neta Zmora
966231d29c
[#9626][feat] Add an auto-deploy transform for using cutlass FP4 MoE kernels (#10304)
Add a transform to replace torch.ops.auto_deploy.torch_quant_nvfp4_moe
with the optimized torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused.

It currently generates wrong results when the number of rows in the MoE FC1 weights is not divisible by 128,
so torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused is not set as the default FP4 MoE implementation (i.e. the transform is disabled).

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
2025-12-29 23:18:15 +02:00
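A minimal sketch of what such an op-swap transform can look like, assuming a torch.fx GraphModule and the two custom ops named in the entry above (already registered under torch.ops.auto_deploy). The argument position of the FC1 weight and the divisibility guard are illustrative assumptions, and the real transform is currently disabled outright rather than guarded:

    import torch

    def swap_nvfp4_moe(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Rewrite eligible torch_quant_nvfp4_moe calls to the fused trtllm op."""
        fused_op = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused
        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue
            if "torch_quant_nvfp4_moe" not in str(node.target):
                continue
            # Hypothetical: assume the FC1 weight is the second positional argument
            # and carries a FakeTensor under node.meta["val"] (standard fx metadata).
            fc1_weight = node.args[1]
            rows = fc1_weight.meta["val"].shape[-2]
            # Per the entry above, the fused kernel mis-computes when the FC1 row
            # count is not a multiple of 128, so only rewrite nodes that satisfy it.
            if rows % 128 == 0:
                node.target = fused_op
        gm.recompile()
        return gm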
Yanchao Lu
965578ca21
[None][infra] Some improvements for Slurm execution path in the CI (#10316)
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-29 06:49:44 -05:00
Yueh-Ting (eop) Chen
9cee32ab39
[https://nvbugs/5625990][fix] Respect VSWA scheme when doing block store for reuse and load block for reuse in KV cache manager (#10183)
Signed-off-by: eopXD <yuehtingc@nvidia.com>
2025-12-29 14:29:14 +08:00
Yanchao Lu
2f8d6d25a8
[None][ci] Waive an intermittent test hang case (#10324)
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-29 13:04:31 +08:00
TensorRT LLM
223411e988 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-29 03:08:32 +00:00
Yanchao Lu
270be801aa
[None][ci] Move remaining DGX-B200 tests to LBD (#9876)
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-28 13:55:39 +08:00
Ziyi Xiong
c59aa8bec5
[TRTLLM-9962][feat] Some optimizations for two-model spec dec (#10208)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
2025-12-28 12:52:04 +08:00
TensorRT LLM
ae6d5766ed [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-28 03:11:53 +00:00
JunyiXu-nv
55bc6a5ff8
[https://nvbugs/5753250][fix] Fix undefined local variable in responses utils (#10154)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
Signed-off-by: JunyiXu-nv <219237550+JunyiXu-nv@users.noreply.github.com>
2025-12-28 06:59:32 +08:00
shivghai
ee07a7c55e
[None][fix] [Gemma3] Fix RoPE for local attention for Gemma3 (#9961)
Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
2025-12-27 11:50:59 -08:00
Guoming Zhang
1865020b6f
[TRTLLM-8577][feat] Clean the Qwen3-next code by removing Qwen3NextCo… (#10228)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
2025-12-27 22:49:55 +08:00
Guoming Zhang
93ac0bc1dc
[TRTLLM-10126][feat] Increase topk upper limit to 22 for NVLinkOneSid… (#10229)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
2025-12-27 22:48:10 +08:00
TensorRT LLM
27976fce9c [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-27 03:08:04 +00:00
Olya Kozlova
55f3cda66d
[None][fix] Fix request_id for best_of/n case (#8368)
Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
2025-12-26 22:20:24 +01:00
Jin Li
c04563657e
[TRTLLM-7735][feat] Attention NVFP4 out support for torch compile (#9740)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-12-27 00:07:20 +08:00
chenfeiz0326
d70aeddc7f
[TRTLLM-8952][feat] Support Multi-Node Disagg Perf Test in CI (#9138)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
2025-12-26 22:50:53 +08:00
Pengyun Lin
684b37df02
[https://nvbugs/5747938][fix] Use local tokenizer (#10230)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2025-12-26 22:08:10 +08:00
Pengyun Lin
c5b0f9e436
[https://nvbugs/5633700][fix] Cache tiktoken vocab for gpt-oss (#10219)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2025-12-26 18:39:03 +08:00
dongfengy
bfc591994c
[https://nvbugs/5745152][fix] Fix some GPTOSS test setups (#10085)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
2025-12-26 17:52:40 +08:00
Jatin Gangani
4a5ef84dc2
[None] [doc] Document perfect MoE router feature for perf analysis (#10303)
Signed-off-by: Jatin Gangani <jgangani@dc2-container-xterm-014.prd.it.nvidia.com>
Co-authored-by: Jatin Gangani <jgangani@dc2-container-xterm-014.prd.it.nvidia.com>
2025-12-26 04:27:40 -05:00
Wanli Jiang
14554ab3f3
[None][feat] Support multi-gpu running for nemotron-v3-nano and super (#10118)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
2025-12-26 11:23:14 +08:00
TensorRT LLM
819d03fa88 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-26 03:08:15 +00:00
Enwei Zhu
13ffe52ad0
[None][fix] Allow YAML config overwriting CLI args for trtllm-eval (#10296)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2025-12-25 15:08:15 -05:00
Neta Zmora
f3f02315df
[None][chore]: small refactoring to auto-deploy MoE operator (#10300)
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
2025-12-25 12:27:11 -05:00
bhsueh_NV
db3430f589
[None][feat] Support VLM part for Mistral Large 3 (#10188)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
2025-12-25 11:20:58 -05:00
Jin Li
7e4cef9def
[None][fix] Cherry-pick conflict changes for PR 7999 PR 8515 (#9446)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-12-25 10:23:04 -05:00
Ziyi Xiong
d8b5aeb061
[https://nvbugs/5652062][fix] Rewind kv_cache and reset draft tokens (#10160)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
2025-12-25 09:13:51 -05:00
ZhichenJiang
46e4af5688
[TRTLLM-9831][perf] Enable 2CTA with autotune for CuteDSL MoE and Grouped GEMM optimizations (#10201)
Signed-off-by: zhichen jiang <zhichenj@NVIDIA.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Co-authored-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2025-12-25 09:04:20 -05:00
Lizhi Zhou
fe12faef81
[https://nvbugs/5752516][chore] unwaive test; fix port conflicts in CI (#10152)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
2025-12-25 08:16:09 -05:00
Iman Tabrizian
cd5cd60ee4
[None][infra] Move install_boost from install_triton.sh to install_base.sh (#10055)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Signed-off-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com>
Co-authored-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com>
2025-12-25 08:09:55 -05:00
Zhenhuan Chen
8462cf6c96
[TRTLLM-9578][feat] make PDL enabled by default (#9695)
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
2025-12-25 07:15:24 -05:00
Jatin Gangani
97b38ac403
[None] [doc] Update IFB performance guide & GPTOSS deployment guide (#10283)
Signed-off-by: Jatin Gangani <jgangani@dc2-container-xterm-014.prd.it.nvidia.com>
Co-authored-by: Jatin Gangani <jgangani@dc2-container-xterm-014.prd.it.nvidia.com>
2025-12-25 05:52:04 -05:00
Emma Qiao
0ecdb69b93
[None][infra] Waive failed tests for main on 12/25 (#10298)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-25 05:22:39 -05:00
Xianjie Qiao
53b81783b1
[None][fix] Fix pageable H2D memcopy issue on GB200 (#10289)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
2025-12-25 18:15:57 +08:00
Jie Li
83e02ee335
[None][chore] Remove NIM TRT-Backend Test Lists (#10232)
Signed-off-by: Jie Li <lijie@nvidia.com>
2025-12-25 04:01:51 -05:00
Enwei Zhu
182b3eb633
[None][ci] Waive TestLlama3_1_8B::test_auto_dtype[False-2] for timeout (#10293)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2025-12-25 02:35:18 -05:00
Gabriel Wu
1d01214ff0
[None][feat] Drop non-deepgemm fp8 block scale gemm (#10256)
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
2025-12-25 14:52:52 +08:00
xinhe-nv
4ae6f6a46c
[None][chore] Add failed cases into waives.txt (#10249)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2025-12-25 01:26:21 -05:00
heyuhhh
7395ca93b6
[None][doc] Add Sparse Attention feature doc (#9648)
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2025-12-25 00:26:18 -05:00
Venky
c059e6caa1
[TRTC-121] [feat] Add recipe selector UI to complement the recipe database (#10125)
Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
2025-12-24 23:56:54 -05:00
gramnarayan
a9eb5afc9f
[#9241][feat] AutoDeploy: Support Eagle3 Speculative Decoding (#9869)
Supports the two-model flow without the overlap scheduler or chain drafter; the drafting model runs in the PyTorch backend.

Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
2025-12-24 23:30:42 -05:00
TensorRT LLM
1f8ed71d5f [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-25 03:08:38 +00:00
Emma Qiao
16fd781e42
[TRTLLM-9862][infra] Move single-gpu tests on rtxpro6000d to pre-merge (#9897)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-24 21:45:33 -05:00
Ziyi Xiong
43178590d1
[TRTLLM-10143][feat] Reuse previous draft requests if possible (#10263)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
2025-12-24 17:48:38 -08:00
Neta Zmora
c4b36d31ff
[#10137][feat] AutoDeploy FP8 MoE refactor (#10138)
The trtllm (cutlass) FP8 MoE operator performs the W3+W1 fusion (concat) during inference, and we want to move this fusion to model-optimization time.

The Cutlass MoE kernel is used through a trtllm torch operator. Its implementation uses two FC operations (fc1 and fc2), while the canonical MoE API defines three GEMM operations and their associated weights (W1, W2, W3); so when we switch from the torch.moe op to the trtllm.moe op, we also change terminology from w1, w2, w3 to fc1, fc2.

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
2025-12-24 18:58:10 +02:00
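A minimal sketch of the W3+W1 fusion step described in the entry above, assuming per-expert weights laid out as [num_experts, intermediate_size, hidden_size] and a w3-then-w1 concat order (both are assumptions; the kernel's actual layout and ordering may differ):

    import torch

    def fuse_fc1(w1: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
        """Concatenate the up (w3) and gate (w1) projections into the single
        FC1 weight expected by the two-FC cutlass MoE kernel; fc2 reuses w2."""
        # [E, I, H] + [E, I, H] -> [E, 2*I, H]; done once at model-optimization
        # time instead of on every forward pass.
        return torch.cat([w3, w1], dim=1).contiguous()

    # Illustrative usage with small random weights (FP8 quantization omitted):
    num_experts, hidden, inter = 4, 256, 512
    w1 = torch.randn(num_experts, inter, hidden)
    w3 = torch.randn(num_experts, inter, hidden)
    fc1 = fuse_fc1(w1, w3)  # shape [4, 1024, 256]

Doing the concat offline avoids re-materializing the fused FC1 tensor on every inference step, which is the point of the refactor.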
Necofish
8614cd3439
[None][fix] fix: resolve GPU memory imbalance in concurrent weight loading (#6472)
Signed-off-by: Necofish <liuxiangyang@mail.ustc.edu.cn>
Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
Signed-off-by: Jie Li <lijie@nvidia.com>
Co-authored-by: Jie Li <lijie@nvidia.com>
2025-12-24 09:43:09 -05:00
Suyog Gupta
e2891a6c77
[#10052][feat] AutoDeploy enable cudagraphs for flashinfer BatchDecode (#10193)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
2025-12-24 05:55:09 -08:00
Stanley Sun
ddac4d7379
[None][test] Add disag-serving auto scaling qa test (#10262)
Signed-off-by: Stanley Sun <stsun@nvidia.com>
2025-12-24 08:43:47 -05:00
Yiqing Yan
69152c4e7c
[None][infra] Check GB200 coherent GPU mapping (#10253)
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
2025-12-24 17:12:36 +08:00
tcherckez-nvidia
56ef97e06e
[#10246][feature] Move AD dashboard to use cudagraph compile backend (#10267)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
2025-12-24 11:09:59 +02:00
Jonas Li
ecea71ca7a
[None][chore] Update tinygemm kernel name (#10248)
Signed-off-by: Jonas Li <6110159+longlee0622@users.noreply.github.com>
2025-12-24 02:33:25 -05:00
shuyixiong
f4f0fe85e9
[TRTLLM-9737][chore] Add rl perf reproduce script and enhance the robustness of Ray tests (#9939)
Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
2025-12-24 15:27:01 +08:00
xinhe-nv
534700ecd9
[None][chore] Add failed cases into waives.txt (#10240)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2025-12-24 02:21:50 -05:00
Yukun He
595daa5089
[TRTLLM-9615][feat] Support synchronization through PP ranks in the distributed tuning system (#10011)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2025-12-24 15:03:10 +08:00
Fanrong Li
156f6453dc
[TRTLLM-9798][feat] Change to use new DeepGEMM MQA sm100 kernel for MTP-3 (#10226)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2025-12-24 14:39:12 +08:00
zackyoray
f6c3bc16b9
[None][docs] Add NIXL-Libfabric Usage to Documentation (#10205)
Signed-off-by: Yoray Zack <62789610+zackyoray@users.noreply.github.com>
2025-12-23 23:05:40 -05:00
Emma Qiao
7b84e48e0f
[None][infra] Waive failed cases om 12/24 (#10257)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-23 22:49:57 -05:00
TensorRT LLM
68cf5c7924 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-24 03:08:16 +00:00
xinhe-nv
fc1f77eafc
[None][chore] Add failed cases into waives.txt (#10204)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Jie Li <76780849+jieli-matrix@users.noreply.github.com>
Co-authored-by: Jie Li <76780849+jieli-matrix@users.noreply.github.com>
2025-12-24 10:37:23 +08:00
Balaram Buddharaju
8c1cfc872b
[TRTLLM-9493][feat] Custom AllToAll for helix parallelism (#9986)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-23 18:14:30 -08:00
Jhao-Ting Chen
92d90fa29a
[None][feat] Expose enable_trt_overlap in Triton_backend brings 1.05x OTPS (#10018)
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
2025-12-23 11:41:31 -06:00
Grzegorz Kwasniewski
0027a01ad5
[https://nvbugs/5680312][fix] Updated test waiving (#9630)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
2025-12-23 09:38:12 -08:00
Grzegorz Kwasniewski
06900a7f19
[TRTLLM-9565][fix] Fix deepseek sharding (#9984)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
2025-12-23 10:28:14 -05:00
Emma Qiao
984c20e0b2
[None][infra] Waive failed cases on 12/23 (#10236)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-23 08:48:54 -05:00
dongfengy
e284d0bf80
[None][infra] Waive flaky unittest/executor/test_rpc_proxy.py and unittest/executor/test_rpc_worker.py tests (#10209)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-23 07:43:13 -05:00
tcherckez-nvidia
64bb1a5155
[None][chore] Update AD coverage to use torch-cudagraph (#10233)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
2025-12-23 07:20:32 -05:00
Roey Azran
8408c40d8b
[https://nvbugs/5702786][fix] Fix race conditions in KV cache communication during unexpected termination (#10076)
Signed-off-by: roeya <165803633+RoeyAzran1992@users.noreply.github.com>
2025-12-23 14:09:51 +02:00
Xianjie Qiao
871c6b435c
[None] [feat] skip batch_tokenize_prompts in CustomDataset (#10214)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
2025-12-23 17:40:57 +08:00
Yukun He
522f1d2bc3
[https://nvbugs/5764627][chore] waive the time-out test (#10222)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2025-12-23 16:36:06 +08:00
Balaram Buddharaju
f2e00a75de
[None][chore] Remove helix test from rtx test list (#10224)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-23 03:07:37 -05:00
Shiyu Li
3ddc9d2b48
[https://nvbugs/5729697][fix] MNNVL Allreduce: use CUDA runtime instead of Macro to get SM version. (#10062)
Signed-off-by: Shiyu Li <shili@nvidia.com>
2025-12-23 16:07:07 +08:00
chenfeiz0326
48c875f8ea
[None][fix] Add OpenSearch URL in slurm_launch.sh for Multinode Perf Sanity Test (#9990)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
2025-12-23 16:02:38 +08:00
Bo Li
cc1323be24
[None][fix] Fix the bug for top_k=10 in NVLinkOneSided AlltoAll. (#10197)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
2025-12-23 02:13:37 -05:00
Yiqing Yan
59b05dc0a8
[None][chore] Bump version to 1.2.0rc7 (#10216)
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
2025-12-23 15:07:47 +08:00
Chuang Zhu
53db3b2612
[https://nvbugs/5741884][fix] unwaive disagg sampler (#10189)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
2025-12-23 14:38:07 +08:00
xinhe-nv
77b591f73b
[None][chore] Add failed cases into waives.txt (#10177)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Jie Li <lijie@nvidia.com>
Signed-off-by: Jie Li <76780849+jieli-matrix@users.noreply.github.com>
Co-authored-by: Jie Li <lijie@nvidia.com>
Co-authored-by: Jie Li <76780849+jieli-matrix@users.noreply.github.com>
Co-authored-by: Larry Xu <197874197+LarryXFly@users.noreply.github.com>
2025-12-23 13:43:50 +08:00
Harshini Komali
d691371eaf
[TRTLLM-9091] [feat] Replace GenAI-Perf with AIPerf (#9310)
Signed-off-by: lkomali <lkomali@nvidia.com>
Signed-off-by: Harshini Komali <157742537+lkomali@users.noreply.github.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2025-12-23 13:25:55 +08:00
Pamela Peng
5bc7ffe379
[None][test] Add qa tests for RTX 6K (#10210)
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
2025-12-22 22:47:09 -05:00
TensorRT LLM
18f8b22956 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-23 03:10:39 +00:00
fredricz-20070104
621156ad44
[None][chore] Fix GB300 support issues (#10196)
Signed-off-by: FredricZ-2007 <226039983+fredricz-20070104@users.noreply.github.com>
Signed-off-by: fredricz-20070104 <226039983+fredricz-20070104@users.noreply.github.com>
2025-12-23 10:42:41 +08:00
Li Min
1e82ff7a0c
[TRTLLM-9989][fix] Fix tvm_ffi aarch64 issue. (#10199)
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
2025-12-23 10:20:40 +08:00
Yuxian Qiu
696f754ef4
[None][fix] avoid implicit cudaStreamSynchronize in sample_async. (#10120)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-23 10:15:40 +08:00
Tailing Yuan
648196f8ae
[TRTLLM-9432][feat] Reduce synchronization and recompilation for qwen3-next (#9691)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
2025-12-23 10:14:29 +08:00
Faraz
f05af48bca
[https://nvbugs/5747674][fix] Add contiguous() before view() in load_expert_w3_w1_weight and load (#10136)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
2025-12-22 21:03:34 -05:00
Fanrong Li
0d2500c631
[TRTLLM-9677][feat] Support DeepSeek-V3.2 tool parser (#10126)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2025-12-23 08:46:47 +08:00
Grzegorz Kwasniewski
ccc64da287
[TRTLLM-9847][fix] WAR fix hanging fused allreduce. (#10087)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
2025-12-23 00:03:32 +01:00
tcherckez-nvidia
12e1cb8d7e
[#9717][chore] Refactor MoE code to use enums (#9910)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
2025-12-22 15:14:56 -05:00
JunyiXu-nv
aaa87abf41
[TRTLLM-7906][feat] Support multiple post process for Responses API (#9908)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
2025-12-22 11:33:34 -05:00
Emma Qiao
ba14a9308e
[None][infra] Waive failed cases on 12/22 (#10200)
Signed-off-by: qqiao <qqiao@nvidia.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-23 00:05:45 +08:00
Pengyun Lin
0f308e95f9
[None][chore] Remove logprobs constraint on trtllm-serve pytorch backend (#9911)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2025-12-22 21:37:22 +08:00
William Zhang
a6a88985cf
[TRTLLM-9409][feat] Pass MRoPE tensors for EPD disagg (#9758)
* Why?

Certain VLMs, such as the Qwen family, need more than just the multimodal
embeddings in the language model; they also need MRoPE position IDs and
deltas. Prior to this commit, only the embeddings could be communicated
from the encoder worker to the prefill worker.

* What?

This commit extends the `DisaggregatedParams` to include the MRoPE
information. It also adjusts several pieces of code required to
communicate that between E, P and D workers.

Closes TRTLLM-9409.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2025-12-22 06:32:49 -05:00
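
To make the shape of the change described above concrete, here is a minimal, hypothetical C++ sketch of disaggregated params extended with MRoPE fields. All names and types in this sketch are illustrative assumptions, not the actual TensorRT-LLM `DisaggregatedParams` definition.

```cpp
// Hypothetical sketch only: illustrates carrying MRoPE data alongside the
// multimodal embeddings in disaggregated params. All names and types here are
// illustrative assumptions, not the actual TensorRT-LLM definitions.
#include <cstdint>
#include <optional>
#include <vector>

struct MropeSketch
{
    std::vector<int64_t> positionIds; // per-token multimodal rotary position IDs
    int64_t positionDelta{0};         // offset applied to positions generated after the prompt
};

struct DisaggParamsSketch
{
    std::vector<float> multimodalEmbeddings; // already communicated before this change
    std::optional<MropeSketch> mrope;        // new: only VLMs such as the Qwen family need it
};

int main()
{
    DisaggParamsSketch params;
    params.mrope = MropeSketch{{0, 1, 2, 3}, 4};
    return params.mrope.has_value() ? 0 : 1;
}
```
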
Bo Li
472fe497dc
[None][chore] NVLinkOneSided AlltoAll Support zero local_num_tokens. (#9822)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
2025-12-22 05:57:12 -05:00
Yan Chunwei
ea6cd76c55
[None][refactor] simplify get_stats and get_kvcache_events with rpc (#9980)
Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
2025-12-22 18:23:43 +08:00
Perkz Zheng
c87f1a6b39
[https://nvbugs/5503479][fix] update trtllm-gen kernels to address few bugs (#10089)
Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
2025-12-22 04:45:33 -05:00
shuyixiong
9e9523c3cc
[https://nvbugs/5762016][chore] Skip a ray test (#10194)
Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
2025-12-22 17:06:19 +08:00
JadoTu
7421224d69
[None][fix] NVFP4 linear method's weight and weight_scale padding (#10148)
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
2025-12-22 15:00:31 +08:00
xinhe-nv
d30ee8101e
[None][chore] Remove closed bugs (#10182)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2025-12-22 01:58:17 -05:00
Yuxian Qiu
237fd0eae4
[https://nvbugs/5666821][chore] unwaive tests. (#9958)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-22 11:39:45 +08:00
TensorRT LLM
f8501f3cc8 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-22 03:08:12 +00:00
Fanrong Li
f0bd60a395
[https://nvbugs/5684820][fix] fix the detokenizer issue for DeepSeek-v3.2 (#10106)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
2025-12-22 10:56:33 +08:00
Jin Li
066b653940
[TRTLLM-9880][feat] Include torch compile tests in QA test list (#10149)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-12-22 10:37:09 +08:00
Yuxian Qiu
2f139ee07e
[https://nvbugs/5701445][chore] unwaive test. (#9949)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-22 10:12:54 +08:00
Chuang Zhu
914dd39127
[None][fix] disable cuda ipc on device without nvlink (L40s) for disagg test (#9735)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
2025-12-22 09:29:24 +08:00
dominicshanshan
d274a4c5d3
[https://nvbugs/5701457][fix] Unwaive ray test. (#10175)
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
2025-12-22 09:25:58 +08:00
Enwei Zhu
5549067966
[None][ci] Waive GPTOSS test case (#10155)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2025-12-22 08:50:44 +08:00
Balaram Buddharaju
5266475014
[None][feat] Cudagraph updates for helix parallelism (#10141)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-21 15:21:52 -05:00
shuyixiong
4fc6036276
[https://nvbugs/5702793][fix] Fix view operation on uncontiguous tensor (#10147)
Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
2025-12-21 11:47:20 -05:00
bhsueh_NV
cd4b4f43fa
[None][feat] Support Eagle3 on Mistral Large3 (#9971)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
2025-12-21 10:25:45 -05:00
Kaiyu Xie
5a611cb8f5
[None] [feat] Enhancements to slurm scripts (#10112)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2025-12-21 10:24:56 -05:00
Emma Qiao
aa5dbb7ca5
[None][infra] Waive failed tests for main branch on 12/21 (#10184)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-21 22:23:46 +08:00
xxi
5ae154022a
[TRTLLM-9872][fix] clear the failed test at CI when enable_configurab… (#10067)
Signed-off-by: xxi <xxi@nvidia.com>
2025-12-21 08:14:50 -05:00
Eran Geva
b15f987972
[None][chore] removed duplicated test from l0_b200.yml (#10090)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
2025-12-21 11:34:01 +02:00
Bo Li
a66eeab537
[TRTLLM-9805][feat] Skip Softmax Attention. (#9821)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Co-authored-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
2025-12-21 02:52:42 -05:00
Balaram Buddharaju
dcd3f7b5ea
[https://nvbugs/5744427][fix] Fix accuracy test OOM (#10173)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-21 02:03:38 -05:00
TensorRT LLM
6c76148b56 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-21 03:08:20 +00:00
Bo Li
77e37d9dd0
[https://nvbugs/5753250][infra] Further waive all tests in _test_openai_responses.py (#10176)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
2025-12-20 10:25:14 -05:00
Enwei Zhu
2ce785f39a
[https://nvbugs/5643631][fix] Fix hostfunc seg fault (#10028)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2025-12-20 07:58:43 -05:00
Enwei Zhu
21a93fbf9d
[TRTLLM-9992][perf] Enable PDL for CuteDSL kernels and overlap MoeOutputMemset (#10043)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2025-12-20 03:12:41 -05:00
TensorRT LLM
3f25db9d3e [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-20 03:07:30 +00:00
Yuxian Qiu
3b3069b390
[https://nvbugs/5747930][fix] Use offline tokenizer for whisper models. (#10121)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-20 09:42:07 +08:00
Yuxian Qiu
e75331480f
[None][fix] fix draft_lengths for CUDA graph capture. (#10004)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-20 09:04:48 +08:00
Anish Shanbhag
7c82605327
[None][fix] enable KV cache reuse for config database (#10094) 2025-12-19 15:16:56 -08:00
Balaram Buddharaju
bee9051484
[None][chore] Waive timing out pre-merge test (#10167)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-19 17:56:33 -05:00
Gal Hubara-Agam
20b69a982a
[#10056][test] AutoDeploy: Add accuracy test for Nemotron SuperV3 (#10131)
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
2025-12-19 13:28:42 -08:00
Chang Liu
5489d188a4
[None][fix] Revert the change and remove device count guard for DSv32 (#9631)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
2025-12-19 15:00:55 -05:00
longcheng-nv
b882393d69
[https://nvbugs/5720357][fix] Fix indice offset overflow in custom Top-K kernel and corresponding UT case (#10027)
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Co-authored-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
2025-12-19 14:58:01 -05:00
Venky
dfa11d810e
[TRTC-102][docs] --extra_llm_api_options->--config in docs/examples/tests (#10005) 2025-12-19 13:48:43 -05:00
JunyiXu-nv
7b71ff6b8a
[https://nvbugs/5722653][fix] Unwaive fixed test (#10157)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
2025-12-19 11:19:20 -05:00
xxi
27e49e2904
[None][fix] waive the failed test test_service_discovery[etcd-load_ba… (#10161)
Signed-off-by: xxi <xxi@nvidia.com>
2025-12-19 06:14:26 -08:00
tcherckez-nvidia
9f6abaf59f
[#9640][feat] Migrate model registry to v2.0 format with composable configs (#9836)
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
2025-12-19 05:30:02 -08:00
xinhe-nv
7b51e3cedb
[TRTLLM-8638][fix] Add failed cases into waives.txt (#10129)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
2025-12-19 17:55:17 +08:00
Emma Qiao
dd8ce68c94
[None][infra] Update waive and waive failed tests for main branch on 12/19 (#10151)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-19 01:20:42 -08:00
Pengyun Lin
ac03915dc3
[TRTLLM-9604][feat] DS R1 & V3.1 tool parser (#10010)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2025-12-19 17:20:03 +08:00
Chang Liu
31bc14b350
[TRTLLM-9654][feat] Support DeepSeek-V32 chat template (#9814)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
2025-12-19 17:05:38 +08:00
yufeiwu-nv
52cee573ad
[TRTLLM-8830][test] Overlap scheduler enhancement perf test: Add qwen3_0,8b and llama3.1 test cases (#10114)
Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
2025-12-19 17:01:52 +08:00
xinhe-nv
cb0444b1b5
[TRTLLM-8638][fix] Add failed cases into waives.txt (#10132)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Co-authored-by: Larry Xu <197874197+LarryXFly@users.noreply.github.com>
2025-12-19 16:07:56 +08:00
JunyiXu-nv
356ad4fe3a
[https://nvbugs/5722653][fix] Address port conflict by assigning different port section in the same node. (#10035)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
2025-12-19 15:34:04 +08:00
Ziyi Xiong
70b4d282c6
[TRTLLM-7736][feat] Incrementally update the inputs of target and draft models (#9708)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
2025-12-19 15:11:25 +08:00
Larry Xu
48dbc61129
[None][chore] Update CODEOWNERS for test cases and test list (#10119)
Signed-off-by: LarryXFly <197874197+LarryXFly@users.noreply.github.com>
2025-12-19 13:38:21 +08:00
William Zhang
478b6b20a1
[#9230][refactor] Replace nemotron patches with custom model implementation (#9751)

* Why?

Patching for nemotron H models was growing out of hand, and made certain
optimizations more complex than they needed to be.

* What?

This commit finally gets rid of them, and replaces them with the custom
model implementation in `modeling_nemotron_h.py`.

Closes #9230
Closes NvBug 5747867

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2025-12-18 19:36:27 -08:00
Balaram Buddharaju
72c5480dfb
[None][chore] Waive test blocking pre-merge 12/18 (#10145)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-18 19:12:05 -08:00
TensorRT LLM
00f70c30a6 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-19 03:11:26 +00:00
Ivy Zhang
9aa40871c2
[TRTLLM-9840][test] switch ucx backend to default backend (#10101)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
2025-12-18 18:54:15 -08:00
TensorRT LLM
a7ac5a6bca [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-19 02:14:37 +00:00
Wangjue Yao
9f283f330b
[None][feat] Support Mooncake transfer engine as a cache transceiver backend (#8309)
Signed-off-by: wjueyao <wyao123@terpmail.umd.edu>
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
2025-12-19 10:09:51 +08:00
Chuang Zhu
e0b2a94309
[None][fix] Fix ready signal in NIXL backend (#10000)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
2025-12-19 09:43:40 +08:00
yuanjingx87
2e88c86f10
[None][infra] Fix issue that lock file generation will skip dependency with comment (#10144)
Signed-off-by: Yuanjing Xue <197832395+yuanjingx87@users.noreply.github.com>
2025-12-18 17:41:23 -08:00
Yukun He
bd5b3c2ac0
[https://nvbugs/5721912][chore] Unwaive the test (#10108)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2025-12-19 09:12:25 +08:00
Anish Shanbhag
91a9ae42d2
[TRTC-71][feat] Add regression testing for config database (#9832)
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
2025-12-18 16:15:38 -08:00
Balaram Buddharaju
799a2ae311
[https://nvbugs/5741331][fix] Fix helix accuracy test (#10021)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-18 15:27:53 -08:00
Chang Liu
a97e411b44
[https://nvbugs/5747911][fix] Use offline data path for the unit test of mmencoder server (#10135)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
2025-12-18 15:19:23 -08:00
Lizhi Zhou
f02782a6f2
[https://nvbugs/5726066][fix] fix auto-scaling related failures (#9845)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Co-authored-by: Emma Qiao <qqiao@nvidia.com>
2025-12-18 16:37:48 -05:00
Enwei Zhu
6fe89ea00f
[TRTLLM-9819][perf] Reuse alltoall workspace for CuteDSL MoE output (#9840)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2025-12-18 10:36:38 -08:00
CarstyYou
0b279f4ad4
[https://nvbugs/5456493][feat] Add fp8 bmm on sm120 (#9687)
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
2025-12-18 22:57:20 +08:00
ZhichenJiang
4e55b83101
[None][perf] Add more optimization options for MOE CuteDSL finalized kernel (#10042)
Signed-off-by: zhichen jiang <zhichenj@NVIDIA.com>
2025-12-18 22:49:28 +08:00
Nikita Korobov
3b4f26e4d1
[None][feat] update TRT-LLM Gen MoE for NvFp4 + bias with tileN=256 (#9734)
Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
2025-12-18 11:58:23 +01:00
yuanjingx87
df15be3fad
[None][infra] Fix slurm job not catching cancelled jobs (#9722)
Signed-off-by: Yuanjing Xue <197832395+yuanjingx87@users.noreply.github.com>
Signed-off-by: yuanjingx87 <197832395+yuanjingx87@users.noreply.github.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
2025-12-18 00:32:43 -08:00
Bo Li
9d7e038bcb
[https://nvbugs/5753250][infra] Waive _test_openai_responses. (#10110)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
2025-12-18 00:15:06 -08:00
Emma Qiao
33a90f2dd2
[None][infra] Waive failed cases for main branch on 12/18 (#10105)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-12-17 21:35:45 -08:00
Yuxian Qiu
bec864a78c
[None][fix] avoid ID conversion for non enable_configurable_moe cases. (#10003)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-12-18 13:29:52 +08:00
yuanjingx87
897a38978d
[None][infra] Update allowlist 2025.12.17 (#10097)
Signed-off-by: Yuanjing Xue <197832395+yuanjingx87@users.noreply.github.com>
2025-12-17 21:11:35 -08:00
Wanli Jiang
601c29ca73
[https://nvbugs/5721644][fix] Update tests for nemotron_h (#9993)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
2025-12-18 12:38:02 +08:00
Lucas Liebenwein
76ec820465
[#7532][feat] AutoDeploy: gather logits before lm head (#9962)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
2025-12-17 19:50:13 -08:00
TensorRT LLM
cfe53e7425 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-18 03:23:35 +00:00
xinhe-nv
4a98f190a8
[None][chore] Add failed cases into waives.txt (#10025)
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
2025-12-17 19:13:52 -08:00
xinhe-nv
c1cfb61b1b
[TRTLLM-9381][feat] Add kimi k2 fp4 tests (#9906)
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
2025-12-17 18:15:27 -08:00
TensorRT LLM
50c2b82f24 [None][infra] Check in most recent lock file from nightly pipeline
Signed-off-by: TensorRT LLM <90828364+tensorrt-cicd@users.noreply.github.com>
2025-12-17 23:45:35 +00:00
tburt-nv
27064f95c7
[None][chore] Clarify copyright header guidance (#9882)
Signed-off-by: Tyler Burt <195370667+tburt-nv@users.noreply.github.com>
2025-12-18 06:38:10 +08:00
tburt-nv
5da7879b38
[None][fix] Revert GHA upgrade for blossom-ci workflow (#10095)
Signed-off-by: Tyler Burt <195370667+tburt-nv@users.noreply.github.com>
2025-12-17 15:57:04 -05:00
Chenghao Zhang
22c6e8a424
[None][fix] Autodeploy: fix some legacy flashinfer attention test errors (#9928)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
2025-12-17 12:27:22 -08:00
Salman Chishti
cb5cd4376e
[None][chore] Upgrade GitHub Actions for Node 24 compatibility (#10045)
Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com>
2025-12-17 09:44:09 -08:00
4809 changed files with 85587 additions and 35933 deletions

19
.github/CODEOWNERS vendored
View File

@ -1,5 +1,18 @@
# This file defines code ownership rules for the repository.
## TensorRT-LLM QA
### Integration Tests
/tests/integration/test_lists/qa @NVIDIA/trt-llm-qa
/tests/integration/defs/examples/test_ray.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/examples/test_redrafter.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/accuracy @NVIDIA/trt-llm-qa-function
/tests/integration/defs/stress_test @NVIDIA/trt-llm-qa-function
/tests/integration/defs/triton_server @NVIDIA/trt-llm-qa-function
/tests/integration/defs/test_e2e.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/disaggregated @NVIDIA/trt-llm-qa-serving
/tests/integration/defs/sysinfo @NVIDIA/trt-llm-qa-perf
/tests/integration/defs/perf @NVIDIA/trt-llm-qa-perf
/tests/integration/defs/perf/disagg @NVIDIA/trt-llm-qa-serving
## TensorRT-LLM Infra
### CI
@ -13,6 +26,11 @@
## TensorRT-LLM - Docs
/docs @NVIDIA/trt-llm-doc-owners
/CODING_GUIDELINES.md @NVIDIA/trt-llm-doc-owners
/CODE_OF_CONDUCT.md @NVIDIA/trt-llm-doc-owners
/CONTAINER_SOURCE.md @NVIDIA/trt-llm-doc-owners
/CONTRIBUTING.md @NVIDIA/trt-llm-doc-owners
/README.md @NVIDIA/trt-llm-doc-owners
## Examples
/examples @NVIDIA/trt-llm-doc-owners
@ -183,6 +201,7 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
## and license compliance when adding, removing, or changing versions of dependencies.
### License Files
/LICENSE @NVIDIA/trt-llm-oss-compliance
/ATTRIBUTIONS-*.md @NVIDIA/trt-llm-oss-compliance
/jenkins/license_cpp.json @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs @NVIDIA/trt-llm-oss-compliance
### Python Dependency Management

View File

@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2
uses: actions/checkout@v6
- name: Get assignee
uses: actions/github-script@v6
uses: actions/github-script@v8
id: get-assignee
with:
github-token: ${{secrets.GITHUB_TOKEN}}

View File

@ -14,7 +14,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v9
- uses: actions/stale@v10
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Issue has not received an update in over 14 days. Adding stale label.'

View File

@ -53,6 +53,7 @@ jobs:
"amukkara",
"anish-shanbhag",
"arekay",
"arysef",
"atrifex",
"Autumn1998",
"baize97",
@ -121,6 +122,7 @@ jobs:
"heyuhhh",
"hijkzzz",
"hlu1",
"hnover-nv",
"HuiGao-NV",
"hvagadia",
"hypdeb",
@ -154,6 +156,7 @@ jobs:
"kaiyux",
"kanghui0204",
"karljang",
"karthikvetrivel",
"katec846",
"Kefeng-Duan",
"KingsleyLiu-NV",
@ -191,6 +194,7 @@ jobs:
"mlefeb01",
"moraxu",
"MrGeva",
"mzweilz",
"Naveassaf",
"nekorobov",
"netanel-haber",
@ -215,6 +219,7 @@ jobs:
"omera-nv",
"pamelap-nvidia",
"pcastonguay",
"pcicotti",
"pdrake-nv",
"peaceh-nv",
"pengbowang-nv",
@ -243,6 +248,7 @@ jobs:
"schetlur-nv",
"shaharmor98",
"shangz-ai",
"sherry-1001",
"shifangx",
"Shixiaowei02",
"Shunkangz",
@ -262,6 +268,7 @@ jobs:
"syuoni",
"Tabrizian",
"talorabr",
"taylor-yb-lee",
"tburt-nv",
"tcherckez-nvidia",
"thorjohnsen",

View File

@ -36,7 +36,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Add bot help comment
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
script: |
const helpMessage = "" +

View File

@ -34,7 +34,7 @@ jobs:
if: github.event_name == 'workflow_dispatch'
steps:
- name: Update commit status
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
script: |
state = 'pending'
@ -60,7 +60,7 @@ jobs:
with:
paths: results/**/results*.xml
- name: Update commit status
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
script: |
github.rest.repos.createCommitStatus({

View File

@ -17,10 +17,10 @@ jobs:
if: github.repository == 'NVIDIA/TensorRT-LLM'
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v6
with:
python-version: '3.x'

View File

@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout private action repository
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
repository: NVIDIA/goggles_action
path: ./.github/actions/goggles_action # local path to store the action

View File

@ -59,10 +59,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'

View File

@ -29,11 +29,11 @@ jobs:
name: Pre-commit Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
ref: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.ref || github.ref }}
- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
python-version: '3.12'
cache: 'pip'

6
.gitignore vendored
View File

@ -40,6 +40,8 @@ tensorrt_llm/libs
tensorrt_llm/bindings.*.so
tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/**/*.pyi
tensorrt_llm/tensorrt_llm_transfer_agent_binding.*.so
tensorrt_llm/tensorrt_llm_transfer_agent_binding.pyi
tensorrt_llm/deep_ep/
tensorrt_llm/deep_ep_cpp_tllm.*.so
tensorrt_llm/deep_ep_cpp_tllm.pyi
@ -56,13 +58,14 @@ tensorrt_llm/scripts
docs/source/**/*.rst
!docs/source/examples/index.rst
!docs/source/deployment-guide/config_table.rst
!docs/source/deployment-guide/note_sections.rst
!docs/source/_includes/note_sections.rst
*.swp
# Testing
.coverage.*
results_trt/
llm-test-workspace/
ad-test-workspace/
# build/debug
*.safetensors
@ -76,6 +79,7 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
.devcontainer/.env
/examples/layer_wise_benchmarks/autotuner_cache/
/examples/layer_wise_benchmarks/profiles/
# User config files

View File

@ -38,8 +38,8 @@ FetchContent_Declare(
FetchContent_Declare(
deepgemm
GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM
GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
GIT_TAG 4ff3f54d9b7ed3129e4f36f9871232ea7ecab86b # nv_dev branch
GIT_SUBMODULES_RECURSE
ON
SOURCE_SUBDIR

View File

@ -487,9 +487,17 @@ else:
f.read()
```
## Documentation Guidelines
#### CLI Options in Documentation
1. When documenting CLI commands for `trtllm-serve`, `trtllm-bench`, `trtllm-eval`, or similar tools, prefer using `--config` over `--extra_llm_api_options` for specifying configuration files.
- `--config` is the preferred, shorter alias for configuration file options.
- Example: `trtllm-serve --model <model_path> --config config.yaml` (preferred)
- Avoid: `trtllm-serve --model <model_path> --extra_llm_api_options config.yaml`
## NVIDIA Copyright
1. All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. The following block of text should be prepended to the top of all files. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
1. All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification. The following block of text should be prepended to the top of all files. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
```cpp
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

View File

@ -10,7 +10,7 @@ state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.<
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-13.0.0-green)](https://developer.nvidia.com/cuda-downloads)
[![torch](https://img.shields.io/badge/torch-2.9.0-green)](https://pytorch.org)
[![version](https://img.shields.io/badge/release-1.2.0rc6-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
[![version](https://img.shields.io/badge/release-1.2.0rc8-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)
[Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](https://nvidia.github.io/TensorRT-LLM/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)

View File

@ -68,6 +68,7 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
ON)
option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
"Using open sourced Cutlass AR gemm kernel" ON)
option(SKIP_SOFTMAX_STAT "Enable Statistics of Skip-Softmax" OFF)
message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")
@ -360,6 +361,11 @@ else()
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
endif()
if(SKIP_SOFTMAX_STAT)
add_compile_definitions("SKIP_SOFTMAX_STAT")
message(STATUS "SKIP_SOFTMAX_STAT is enabled")
endif()
# Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
# be found in
# https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html#index-mcmodel_003dmedium-1

View File

@ -380,6 +380,7 @@ public:
, mBeamWidth(beamWidth)
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
, mNumFrontBlocksRemoved(0)
, mCurrentPrepopulatedPromptLen(std::numeric_limits<SizeType32>::max())
{
auto const numWindowSizes = windowSizeToMetadata.size();
mCacheBlockIds.reserve(numWindowSizes);
@ -500,6 +501,20 @@ public:
return mKvCacheRetentionConfig.getDirectory();
}
[[nodiscard]] SizeType32 getCurrentPrepopulatedPromptLen() const
{
return mCurrentPrepopulatedPromptLen;
}
void setCurrentPrepopulatedPromptLen(SizeType32 currentPrepopulatedPromptLen)
{
TLLM_CHECK_WITH_INFO(currentPrepopulatedPromptLen <= mCurrentPrepopulatedPromptLen,
"currentPrepopulatedPromptLen must be updated non-increasingly due to the "
"assumption that smaller window sizes have shorter or equal"
"currentPrepopulatedPromptLen in WindowSizeManager::loadOrAllocateBlocks.");
mCurrentPrepopulatedPromptLen = currentPrepopulatedPromptLen;
}
private:
// Request id of the sequence
LlmRequest::RequestIdType mRequestId;
@ -517,6 +532,8 @@ private:
SizeType32 mNumFrontBlocksRemoved;
// Set of used blocks by the sequence
std::set<KVCacheBlock::IdType> mUsedBlocks;
// Current prepopulated prompt length
SizeType32 mCurrentPrepopulatedPromptLen;
};
// attach metadata to a pool pointer
@ -631,7 +648,7 @@ public:
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
@ -836,8 +853,8 @@ public:
//! \param blockKeys Key of each block.
//! \param blockIds Id of each block.
//! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
//! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
//! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
[[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
bool pinBlocks = false);
@ -869,8 +886,8 @@ public:
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);
//! \brief Unpin blocks by starting from a block id and walking prev pointers.
void unpinBlocksById(KVCacheBlock::IdType blockId);
//! \brief Unpin blocks by block ids directly
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
{
@ -1086,7 +1103,7 @@ public:
std::optional<KVCacheBlock::IdType> releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@ -1095,7 +1112,7 @@ public:
/// @param sequence The generation request whose blocks should be pinned.
void pinBlocks(GenerationRequest& sequence);
void unpinBlocksById(KVCacheBlock::IdType blockId);
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
@ -1116,7 +1133,7 @@ public:
void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
[[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
SizeType32 windowSize, bool pinBlocks = false)
{
@ -1567,7 +1584,7 @@ public:
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;
/// \brief Store blocks for reuse for a given request id
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
= 0;
@ -1661,7 +1678,7 @@ public:
BlockKey const& blockKey, SizeType32 windowSize)
= 0;
virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
};
class KVCacheManager : public BaseKVCacheManager
@ -1922,7 +1939,7 @@ public:
//! \brief Store newest blocks for reuse
void storeNewBlock(LlmRequest const& llmRequest) override;
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@ -1943,7 +1960,7 @@ public:
void pinBlocks(LlmRequest::RequestIdType requestId) override;
void unpinBlocksById(KVCacheBlock::IdType blockId) override;
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;
std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;

View File

@ -1667,6 +1667,12 @@ public:
[](auto reason) { return reason == executor::FinishReason::kLENGTH; });
}
[[nodiscard]] bool isFinishedDueToCancellation() const noexcept
{
return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
[](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
}
[[nodiscard]] bool isTimedOut() const
{
if (!mAllottedTimeMs.has_value())

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/executor/serialization.h"
#include <atomic>
#include <vector>
namespace tensorrt_llm::executor::kv_cache
@ -27,8 +28,9 @@ class CommState;
struct DataContext
{
public:
explicit DataContext(int tag)
explicit DataContext(int tag, std::atomic<bool> const& transferTerminate = sDefaultTransferTerminate)
: mTag{tag}
, mTransferTerminate(transferTerminate)
{
}
@ -37,8 +39,15 @@ public:
return mTag;
}
[[nodiscard]] std::atomic<bool> const& getTransferTerminate() const noexcept
{
return mTransferTerminate;
}
private:
inline static std::atomic<bool> sDefaultTransferTerminate{false};
int const mTag;
std::atomic<bool> const& mTransferTerminate;
};
class Connection
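
The hunk above threads an optional termination flag through `DataContext`. Below is a minimal, self-contained sketch of the cooperative-cancellation pattern this enables; the simplified class mirrors only the constructor and getter shown in the diff, and the polling worker loop is a hypothetical illustration, not actual TensorRT-LLM code.

```cpp
// Sketch of cooperative cancellation via the transferTerminate flag added in the hunk.
#include <atomic>
#include <chrono>
#include <thread>

class DataContextSketch
{
public:
    explicit DataContextSketch(int tag, std::atomic<bool> const& transferTerminate)
        : mTag{tag}
        , mTransferTerminate(transferTerminate)
    {
    }

    [[nodiscard]] int getTag() const noexcept
    {
        return mTag;
    }

    [[nodiscard]] std::atomic<bool> const& getTransferTerminate() const noexcept
    {
        return mTransferTerminate;
    }

private:
    int const mTag;
    std::atomic<bool> const& mTransferTerminate;
};

int main()
{
    std::atomic<bool> terminate{false};
    DataContextSketch ctx{/*tag=*/0, terminate};

    // A transfer worker would poll the flag between chunks and abort early when
    // the owner requests termination (e.g. on an unexpected peer shutdown).
    std::thread worker([&ctx] {
        while (!ctx.getTransferTerminate().load(std::memory_order_acquire))
        {
            std::this_thread::sleep_for(std::chrono::milliseconds(1)); // send next chunk here
        }
    });

    terminate.store(true, std::memory_order_release);
    worker.join();
    return 0;
}
```
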

View File

@ -1468,7 +1468,8 @@ public:
DEFAULT = 0,
MPI = 1,
UCX = 2,
NIXL = 3
NIXL = 3,
MOONCAKE = 4
};
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,

View File

@ -274,13 +274,20 @@ private:
std::optional<SyncMessage> mSyncMessage;
};
enum class TransferState : uint8_t
{
kIN_PROGRESS,
kSUCCESS,
kFAILURE,
};
// Data structure for checking the status of active transfer operations.
class TransferStatus
{
public:
virtual ~TransferStatus() = default;
[[nodiscard]] virtual bool isCompleted() const = 0;
virtual void wait() const = 0;
virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
};
struct BaseAgentConfig
@ -288,6 +295,8 @@ struct BaseAgentConfig
std::string mName;
bool useProgThread;
bool multiThread;
bool useListenThread;
unsigned int numWorkers;
};
class BaseTransferAgent
@ -391,6 +400,14 @@ template <typename... Args>
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
return func(std::forward<Args>(args)...);
}
if (backend == "mooncake")
{
auto& loader = DynLibLoader::getInstance();
using CreateMooncakeFuncType = std::unique_ptr<BaseTransferAgent> (*)(BaseAgentConfig const*);
auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
"libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
return func(std::forward<Args>(args)...);
}
TLLM_THROW("Unknown backend name.");
}
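
The `TransferStatus` interface above replaces a blocking `wait()` with a timeout-aware `wait()` that reports a `TransferState`. A minimal sketch of the polling pattern this allows follows; only the enum values and the `wait(timeout_ms)` shape come from the diff, while the `DummyStatus` class is a hypothetical stand-in for a real agent's status object (the mooncake backend registered above would return such an object from its transfer calls).

```cpp
// Sketch of polling a transfer with a timeout-aware wait(), under assumed names.
#include <cstdint>
#include <cstdio>

enum class TransferState : uint8_t
{
    kIN_PROGRESS,
    kSUCCESS,
    kFAILURE,
};

class DummyStatus
{
public:
    // Pretend the transfer completes after a few polls; a real status object
    // would reflect the underlying agent's progress instead.
    TransferState wait(int64_t /*timeout_ms*/ = -1)
    {
        return (--mPollsLeft > 0) ? TransferState::kIN_PROGRESS : TransferState::kSUCCESS;
    }

private:
    int mPollsLeft{3};
};

int main()
{
    DummyStatus status;
    TransferState state = TransferState::kIN_PROGRESS;
    // Poll with a bounded wait instead of blocking forever, so a stuck peer or a
    // terminate request can be handled between attempts.
    while (state == TransferState::kIN_PROGRESS)
    {
        state = status.wait(/*timeout_ms=*/100);
    }
    std::printf("transfer %s\n", state == TransferState::kSUCCESS ? "succeeded" : "failed");
    return 0;
}
```
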

View File

@ -104,12 +104,14 @@ public:
[[nodiscard]] SizeType32 constexpr getTensorParallelRank() const noexcept
{
return mRank % mTensorParallelism;
// Layout: pp is outermost, then tp, then cp is innermost (consecutive).
return (mRank % (mTensorParallelism * mContextParallelism)) / mContextParallelism;
}
[[nodiscard]] SizeType32 constexpr getContextParallelRank() const noexcept
{
return (mRank % (mTensorParallelism * mContextParallelism)) / mTensorParallelism;
// Layout: pp is outermost, then tp, then cp is innermost (consecutive).
return mRank % mContextParallelism;
}
[[nodiscard]] SizeType32 constexpr getLocalRank() const noexcept
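
The new layout places pp outermost, then tp, with cp innermost (consecutive ranks differ in cp first). A small self-contained worked example follows; the tp/cp formulas are taken from the diff, while the pp formula is an assumption based on the stated layout and is not shown in the hunk.

```cpp
// Worked example of the rank layout from the hunk: pp outermost, tp next, cp innermost.
#include <cstdio>

int main()
{
    constexpr int tp = 2; // tensor parallelism
    constexpr int cp = 2; // context parallelism
    constexpr int pp = 2; // pipeline parallelism

    for (int rank = 0; rank < tp * cp * pp; ++rank)
    {
        int const tpRank = (rank % (tp * cp)) / cp; // from the diff
        int const cpRank = rank % cp;               // from the diff
        int const ppRank = rank / (tp * cp);        // assumed: pp is outermost
        std::printf("rank=%d -> pp=%d tp=%d cp=%d\n", rank, ppRank, tpRank, cpRank);
    }
    // Ranks 0..7 map to (pp,tp,cp): (0,0,0) (0,0,1) (0,1,0) (0,1,1) (1,0,0) (1,0,1) (1,1,0) (1,1,1)
    return 0;
}
```
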

View File

@ -69,6 +69,11 @@ PREPROCESSOR_FLAGS += -DUSE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE
# Do we want to use half accumulation for flash attention
PREPROCESSOR_FLAGS += -DHALF_ACCUMULATION_FOR_FLASH_ATTENTION
# Print the resulted sparsity given threshold in Skip-Softmax attention
# Note: To use it inside TRT-LLM, you only need to run "python scripts/build_wheel.py -D SKIP_SOFTMAX_STAT=ON ...".
# Turn this on manually only if you want to build and run the unit test (bin/fmha.exe) with SKIP_SOFTMAX_STAT.
# PREPROCESSOR_FLAGS += -DSKIP_SOFTMAX_STAT
# Add FLAGS when generating cubins.
ifdef GENERATE_CUBIN
PREPROCESSOR_FLAGS += -DGENERATE_CUBIN

View File

@ -154,7 +154,9 @@ spec_fields = (
'head_size_v',
'sage_block_sizes',
'output_dtype',
'is_mtp')
'is_mtp',
'enable_skip_softmax',
)
kernel_spec = namedtuple('kernel_spec', spec_fields)
kernel_spec.__new__.__defaults__ = (
1, # ctas_per_head
@ -179,7 +181,9 @@ kernel_spec.__new__.__defaults__ = (
0, # head size of V
None, # sage_block_sizes
None, # output_dtype, same as dtype by default.
False) # use MTP or not
False, # use MTP or not
False, # enable skip softmax
)
generate_cu_trtllm = os.environ.get('GENERATE_CU_TRTLLM',
'False').lower() == 'true'
@ -1435,6 +1439,7 @@ using Ktraits = {kernel_traits_header}
USE_TMA_STORE,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag},
{enable_skip_softmax_flag},
{output_dtype_},
{sage_block_size_q},
{sage_block_size_k},
@ -1458,6 +1463,7 @@ using Ktraits_causal = {kernel_traits_header}
USE_TMA_STORE,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag},
{enable_skip_softmax_flag},
{output_dtype_}>;
using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
@ -1478,6 +1484,7 @@ using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
USE_TMA_STORE && false,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag},
{enable_skip_softmax_flag},
{output_dtype_}>;
using Ktraits_custom_mask = {kernel_traits_header}
@ -1498,6 +1505,7 @@ using Ktraits_custom_mask = {kernel_traits_header}
USE_TMA_STORE && false,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag},
{enable_skip_softmax_flag},
{output_dtype_}>;
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -1835,6 +1843,8 @@ def encode_name(kernel_spec):
if kernel_spec.enable_attn_logit_softcapping:
feature_tags += '_softcapping'
if kernel_spec.enable_skip_softmax:
feature_tags += '_skipSoftmax'
if kernel_spec.sage_block_sizes:
feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}"
if kernel_spec.output_dtype:
@ -2131,6 +2141,8 @@ def get_kernel_code(kspec, kname, lname):
return_softmax_stats_flag = pythonBoolean2cpp[kspec.return_softmax_stats]
enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
# needed by warpspec kernels.
fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"]
kernel_traits_header = "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" if fp8_kernel \
@ -2331,6 +2343,8 @@ def get_api_code(specs_names):
f'&& sage_block_size_k == {sage_block_size_k} ' \
f'&& sage_block_size_v == {sage_block_size_v} '
il_check += '&& enable_skip_softmax ' if kspec.enable_skip_softmax else '&& !enable_skip_softmax '
il_check += '&& params.use_int8_scale_max ' if kspec.has_scale_max else '&& !params.use_int8_scale_max '
slen = kspec.seq_len * kspec.ctas_per_head if not kspec.flash_attention else 0
@ -2607,6 +2621,7 @@ const bool warp_specialization = launch_params.warp_specialization
const bool use_tma = launch_params.use_tma;
const bool use_flash_attention = launch_params.flash_attention;
const bool enable_attn_logit_softcapping = launch_params.enable_attn_logit_softcapping;
const bool enable_skip_softmax = launch_params.enable_skip_softmax;
const int attention_input_layout = static_cast<int>(launch_params.attention_input_layout);
// tiled variant uses ldgsts
const bool use_tiled = launch_params.use_granular_tiling;
@ -2785,6 +2800,8 @@ def get_kernel_traits_code(specs_names):
enable_attn_logit_softcapping_flag = pythonBoolean2cpp[
kspec.enable_attn_logit_softcapping]
enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
tmp = dict(locals(), **kspec._asdict())
if effective_sm < 90:
@ -2903,7 +2920,8 @@ def get_kernel_traits_code(specs_names):
{input_layout_flag},
__use_tma_store__ /* USE_TMA_STORE */,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag}>;
{return_softmax_stats_flag},
{enable_skip_softmax_flag}>;
printf("%s %d %d %s %d %d\\n",
\"{kname}\",
@ -3062,9 +3080,16 @@ def get_kernel_traits_code(specs_names):
# For now:
# 1. Hopper head_size 128 kernel uses cubins for performance regressions.
# 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins for accuracy regressions (will be fixed).
# 3. For skip-softmax attention feature, we force not to use cubins.
# You should set the condition `use_cubin_header` to false if you have modified the source codes of those kernels that use cubins.
# This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins.
def use_cubin_header(sm, head_size, dtype, output_dtype=None):
def use_cubin_header(sm,
head_size,
dtype,
output_dtype=None,
enable_skip_softmax=False):
if enable_skip_softmax:
return False
if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
return False
return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)
@ -3079,7 +3104,8 @@ def get_cubin_header(kernel_traits, specs_names):
launchers_dict = {}
for kspec, fname, lname, kname in specs_names:
if generate_cu_trtllm and not use_cubin_header(
kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype):
kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype,
kspec.enable_skip_softmax):
continue
name = fname.replace('.', '_')
data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
@ -3111,8 +3137,9 @@ def get_cubin_header(kernel_traits, specs_names):
'q_kv_', '').replace('q_paged_kv_', '').replace(
'q_k_v_', '').replace('ws_', '').replace(
'softcapping_',
'').replace('sage_',
'').replace('output_', ''))
'').replace('sage_', '').replace(
'skipSoftmax_',
'').replace('output_', ''))
flash_attention = 'flash_attention' in kname
warp_specialization = 'tma_ws' in kname
toks = tname.split('_')
@ -3209,6 +3236,8 @@ def get_cubin_header(kernel_traits, specs_names):
return_softmax_stats_flag = pythonBoolean2cpp[sm != '90' or (
sm == '90' and '_softmax' in kname)]
enable_skip_softmax_flag = pythonBoolean2cpp['_skipSoftmax' in kname]
# meta_unroll_step
meta_unroll_step = unroll_step if ('_nl' in kname
or '_ws' in kname) else '0'
@ -3235,7 +3264,8 @@ def get_cubin_header(kernel_traits, specs_names):
def get_lname_from_kname(kname: str) -> str:
if use_cubin_header(int(sm), int(head_size), prec.lower(),
output_prec.lower()):
output_prec.lower(),
enable_skip_softmax_flag):
return 'nullptr'
lname = kname.replace('_kernel', '')
mask_types = [
@ -3253,15 +3283,15 @@ def get_cubin_header(kernel_traits, specs_names):
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
'''.format(**locals()) if use_cubin_header(int(sm),
int(head_size), prec.lower(),
output_prec.lower()) else '''\
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
'''.format(**locals()) if use_cubin_header(int(sm), int(head_size),
prec.lower(), output_prec.lower(),
enable_skip_softmax_flag) else '''\
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
'''.format(**locals())
else:
code = '''\
@ -3269,7 +3299,7 @@ def get_cubin_header(kernel_traits, specs_names):
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}}}\
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}}}\
'''.format(**locals())
if sm in metadata_v2_dict:
metadata_v2_dict[sm].append(code)
@ -3377,7 +3407,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
bool mAlibiSupported;
bool mTiled;
bool mEnableAttnLogitSoftcapping;
bool mReturnSoftmaxStats;{launcher_line}
bool mReturnSoftmaxStats;
bool mEnableSkipSoftmax;{launcher_line}
}} sMhaKernelMetaInfosV2[] = {{
{metadata_v2}
}};
@ -3438,6 +3469,7 @@ static const struct TestMetaV2
bool mTiled;
bool mEnableAttnLogitSoftcapping;
bool mReturnSoftmaxStats;
bool mEnableSkipSoftmax;
}} metaV2[] = {{
{metadata_v2}
}};
@ -3484,7 +3516,8 @@ struct FusedMultiHeadAttentionKernelMetaInfoV2
bool mAlibiSupported;
bool mTiled;
bool mEnableAttnLogitSoftcapping;
bool mReturnSoftmaxStats;{launcher_line}
bool mReturnSoftmaxStats;
bool mEnableSkipSoftmax;{launcher_line}
}};
extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[];
@ -3580,7 +3613,8 @@ struct FusedMultiHeadAttentionKernelMetaInfoV2
bool mAlibiSupported;
bool mTiled;
bool mEnableAttnLogitSoftcapping;
bool mReturnSoftmaxStats;{launcher_line}
bool mReturnSoftmaxStats;
bool mEnableSkipSoftmax;{launcher_line}
}};
extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[] = {{
@ -3637,7 +3671,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_
return '\n'.join(lines)
target = "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled"
new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, nullptr},'
new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, false, nullptr},'
result = modify_kernel_line(result, target, new_line)
# make sure only one empty line at the end
@ -3801,7 +3835,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):
# Note this will be used in TRT-LLM.
def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
def enumerate_hgmma_flash_warpspec_kernels(specs,
sm=90,
dtype='fp16',
enable_skip_softmax=False):
scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
@ -3851,7 +3888,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
return_softmax_stats=return_softmax,
scheduling_mode=scheduling_mode,
input_layout=input_layout))
input_layout=input_layout,
enable_skip_softmax=enable_skip_softmax))
specs.append(
kernel_spec(
@ -3883,7 +3921,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
return_softmax_stats=return_softmax,
scheduling_mode=scheduling_mode,
input_layout=input_layout))
input_layout=input_layout,
enable_skip_softmax=enable_skip_softmax))
specs.append(
kernel_spec(
@ -3915,7 +3954,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
return_softmax_stats=return_softmax,
scheduling_mode=scheduling_mode,
input_layout=input_layout))
input_layout=input_layout,
enable_skip_softmax=enable_skip_softmax))
'''
smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
+ (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
@ -3967,7 +4007,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
sm=90,
dtype='e4m3',
sage_block_sizes=None,
output_dtype=None):
output_dtype=None,
enable_skip_softmax=False):
scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
@ -4021,7 +4062,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
scheduling_mode=scheduling_mode,
input_layout=input_layout,
sage_block_sizes=sage_block_sizes,
output_dtype=output_dtype))
output_dtype=output_dtype,
enable_skip_softmax=enable_skip_softmax))
# 64 < D <=128: KV_STEP = 128
specs.append(
@ -4056,7 +4098,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
scheduling_mode=scheduling_mode,
input_layout=input_layout,
sage_block_sizes=sage_block_sizes,
output_dtype=output_dtype))
output_dtype=output_dtype,
enable_skip_softmax=enable_skip_softmax))
# 128 < D <=256: KV_STEP = 128
specs.append(
@ -4092,7 +4135,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
scheduling_mode=scheduling_mode,
input_layout=input_layout,
sage_block_sizes=sage_block_sizes,
output_dtype=output_dtype))
output_dtype=output_dtype,
enable_skip_softmax=enable_skip_softmax))
if not skip_mla_combination:
# context MLA (192x128)
@ -6374,13 +6418,21 @@ def enumerate_kernels():
enumerate_igmma_kernels(specs, sm=90)
enumerate_qgmma_kernels(specs, sm=90)
# need to add bf16 kernels if needed
enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16')
enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='bf16')
enumerate_qgmma_flash_warpspec_kernels(specs, sm=90, dtype='e4m3')
enumerate_qgmma_flash_warpspec_kernels(specs,
sm=90,
dtype='e4m3',
output_dtype="bf16")
for enable_skip_softmax in [False, True]:
if enable_skip_softmax and 'DISABLE_SKIP_SOFTMAX' in os.environ:
continue
enumerate_hgmma_flash_warpspec_kernels(
specs, sm=90, dtype='fp16', enable_skip_softmax=enable_skip_softmax)
enumerate_hgmma_flash_warpspec_kernels(
specs, sm=90, dtype='bf16', enable_skip_softmax=enable_skip_softmax)
enumerate_qgmma_flash_warpspec_kernels(
specs, sm=90, dtype='e4m3', enable_skip_softmax=enable_skip_softmax)
enumerate_qgmma_flash_warpspec_kernels(
specs,
sm=90,
dtype='e4m3',
output_dtype="bf16",
enable_skip_softmax=enable_skip_softmax)
# For now SageAttention only needs BF16
# block_size_q should be divisible by 64

View File

@ -256,7 +256,8 @@ struct Compute
actual_kv_seqlen, alibi_head_scale, \
USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \
: (q_step_idx * STEP_Q + head_info.q_tile_offset), \
kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, kv_step_idx == kv_idx_end - 1);
kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \
&shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1);
////////////////////////////////////////////////////////////////////////////////////////////////
@ -360,6 +361,12 @@ struct Compute
// Contiguous QKV FMHA assumes q, and kv have the same sequence length.
int const actual_kv_seqlen = SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen;
// Update threshold of Skip-Softmax
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
{
softmax.skip_softmax_threshold = params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
}
// Calculate the alibi head_scaling_factor.
float alibi_head_scale
= APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>(head_info.bidh, params.alibi_params) : 0.f;
@ -513,6 +520,13 @@ struct Compute
}
}
}
#ifdef SKIP_SOFTMAX_STAT
if (tidx == 0)
{
atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks);
atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks);
}
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////
@ -522,8 +536,15 @@ struct Compute
Compute_tile_o& ctile_o, float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M],
int const tidx, int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset,
int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, Circular_buffer_kv_reader& cbr_v,
OrderedMutexAccessor& mutex, bool complete = false)
OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote, bool complete = false)
{
// Skip-softmax vote initialization
if (tidx == 0)
{
// Note that we need a named_barrier_wait in compute_single_tile to make sure the init happens before voting.
*skip_softmax_vote = 1;
}
// load the scales of K/V from global memory
#define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \
if constexpr (block_size > 0) \
@ -557,6 +578,10 @@ struct Compute
// Ctile_p is only used once by each n step.
ctile_p.clear();
// If skip_softmax is enabled, make sure there is no race between the initialization and writing of
// skip_softmax_vote.
named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128);
// BMM1 (Q x K').
warpgroup_arrive();
@ -626,8 +651,22 @@ struct Compute
softmax.apply_alibi_and_mask<APPLY_MASK>(
ctile_p, params.alibi_params, alibi_head_scale, actual_kv_seqlen, row_offset, col_offset);
// Softmax Exp, max/sum, and update scales.
softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum);
// Softmax Exp, max/sum, and update scales. If it returns false, we skip the rest.
if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote))
{
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1)
{
// Notify another warpgroup to execute QGMMA.
mutex.named_bar_arrive();
}
// Need to wait V, otherwise compute-sanitizer synccheck will fail.
int ready2 = cbr_v.peek();
if (!ready2)
{
cbr_v.wait();
}
return;
}
// experiments show that here is the best place to load scales of V
float scales_v[SAGE_BLOCKS_PER_STEP_V];
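A short restatement of the criterion implied by the threshold update in this hunk, with c = skip_softmax_threshold_scale_factor and L_kv = actual_kv_seqlen:

\[
\text{threshold} = \frac{c}{L_{kv}}, \qquad
\text{skip the tile} \iff \exp(m_{\text{local}} - m_{\text{global}}) < \frac{c}{L_{kv}}
\iff m_{\text{local}} - m_{\text{global}} < \ln\frac{c}{L_{kv}}.
\]

In the EXP2F_OPTIMIZATION path of the softmax hunk below, the same test is evaluated as exp2f((local_max - global_max) * scale) < threshold, with scale taken from scale_bmm1.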

View File

@ -17,6 +17,8 @@
#pragma once
#include "fmha/hopper/arrive_wait.h"
#include <fmha/softmax.h>
#include <fmha/traits.h>
#include <fmha/utils.h>
@ -104,6 +106,12 @@ struct Softmax_base
CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK
};
// There are 2 warpgroups so 0x3 and 0x4 are used
enum
{
SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID
};
// Ctor.
template <typename Params>
inline __device__ Softmax_base(Params params, int tidx)
@ -114,6 +122,11 @@ struct Softmax_base
, log2_chunked_attention_size_(params.log2_chunked_attention_size)
, packed_mask_ptr_{reinterpret_cast<uint32_t*>(params.packed_mask_ptr)}
, params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes}
#ifdef SKIP_SOFTMAX_STAT
, total_blocks(0)
, skipped_blocks(0)
#endif
, skip_softmax_threshold(0)
{
int warp = tidx / 32;
@ -330,24 +343,22 @@ struct Softmax_base
}
// Calculate max/sum, and update flash-attention scales.
// Returns false if skipped due to the skip-softmax attention feature.
template <bool IS_FIRST_COL>
inline __device__ void compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
inline __device__ bool compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
{
float const scale = reinterpret_cast<float const&>(scale_bmm1_);
// whether this warpgroup skips the softmax
constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
bool skip = may_skip;
// Row-wise max of current tile.
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
{
if (IS_FIRST_COL)
{
local_max_[mi] = elt_[mi][0];
}
else
{
local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
}
local_max_[mi] = elt_[mi][0];
#pragma unroll
for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
{
@ -355,6 +366,56 @@ struct Softmax_base
}
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
if constexpr (may_skip)
{
// AND(&) the CORES_M results, so that `skip` indicates whether to skip
// the CORES_M(=2) rows
if constexpr (!EXP2F_OPTIMIZATION)
{
skip &= expf(local_max_[mi] - global_max[mi]) < skip_softmax_threshold;
}
else
{
skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold;
}
}
if (!IS_FIRST_COL)
{
local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
}
}
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
{
#ifdef SKIP_SOFTMAX_STAT
total_blocks++;
#endif
if constexpr (may_skip)
{
// AND(&) the results together in a warp, so that `skip` indicates whether to skip
// all the 16 rows managed by this warp.
// Every 4 threads (e.g. T0~T3) have the same `skip`, so a mask of 0x11111111 would suffice
// instead of 0xffffffff, but the perf is the same.
skip = __all_sync(0xffffffff, skip);
if (threadIdx.x % 32 == 0)
{
// The leader of each warp votes.
atomicAnd(skip_softmax_vote, uint32_t(skip));
}
// WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
named_barrier_wait(SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
skip = *((uint32_t volatile*) skip_softmax_vote);
if (skip)
{
#ifdef SKIP_SOFTMAX_STAT
skipped_blocks++;
#endif
return false;
}
}
}
// Softmax Exp.
@ -436,6 +497,7 @@ struct Softmax_base
global_max[mi] = max_new;
}
}
return true;
}
// Update flash attention scales and pack elements for BMM2.
@ -513,6 +575,13 @@ struct Softmax_base
float correction_[Mma_tile_p::CORES_M];
// The packed mask.
uint4 packed_mask_;
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold.
float skip_softmax_threshold;
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax
uint32_t total_blocks;
uint32_t skipped_blocks;
#endif
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -868,9 +937,10 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
}
// Calculate max/sum, and update flash-attention scales.
// Returns false if skipped due to the skip-softmax attention feature.
template <bool IS_FIRST_COL>
inline __device__ void compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
inline __device__ bool compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
{
float const scale = reinterpret_cast<float const&>(this->scale_bmm1_);
float(&local_max_)[Mma_tile_p::CORES_M] = this->local_max_;
@ -878,18 +948,15 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
float(&correction_)[Mma_tile_p::CORES_M] = this->correction_;
float(&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2] = this->elt_;
// whether this warpgroup skips the softmax
constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
bool skip = may_skip;
// Row-wise max of current tile.
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
{
if (IS_FIRST_COL)
{
local_max_[mi] = elt_[mi][0];
}
else
{
local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
}
local_max_[mi] = elt_[mi][0];
#pragma unroll
for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
{
@ -897,6 +964,56 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
}
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
if constexpr (may_skip)
{
// AND(&) the CORES_M results, so that `skip` indicates whether to skip
// the CORES_M(=2) rows
if constexpr (!EXP2F_OPTIMIZATION)
{
skip &= expf(local_max_[mi] - global_max[mi]) < this->skip_softmax_threshold;
}
else
{
skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < this->skip_softmax_threshold;
}
}
if (!IS_FIRST_COL)
{
local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
}
}
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
{
#ifdef SKIP_SOFTMAX_STAT
this->total_blocks++;
#endif
if constexpr (may_skip)
{
// AND(&) the results together in a warp, so that `skip` indicates whether to skip
// all the 16 rows managed by this warp.
// Every 4 threads (e.g. T0~T3) have the same `skip`, so a mask of 0x11111111 would suffice
// instead of 0xffffffff, but the perf is the same.
skip = __all_sync(0xffffffff, skip);
if (threadIdx.x % 32 == 0)
{
// The leader of each warp votes.
atomicAnd(skip_softmax_vote, uint32_t(skip));
}
// WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
named_barrier_wait(Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
skip = *((uint32_t volatile*) skip_softmax_vote);
if (skip)
{
#ifdef SKIP_SOFTMAX_STAT
this->skipped_blocks++;
#endif
return false;
}
}
}
// Softmax Exp.
@ -987,6 +1104,7 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
global_max[mi] = max_new;
}
}
return true;
}
// Update flash attention scales and pack elements for BMM2.
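The voting protocol above (initialize the shared vote to 1, AND each warp's decision with __all_sync, let the warp leader atomicAnd into shared memory, then read the result after a named barrier) can be condensed into the following self-contained CUDA sketch. It is only an illustration: voteKernel, localMax, globalMax and the plain __syncthreads() barriers are hypothetical stand-ins for the warpgroup-scoped named barriers and shared-memory layout of the real kernel.

#include <cstdio>
#include <cuda_runtime.h>

// One block stands in for one "warpgroup" of 128 threads; each thread owns one row.
__global__ void voteKernel(float const* localMax, float const* globalMax, float threshold, int* outSkip)
{
    __shared__ unsigned vote;
    int const tid = threadIdx.x;
    if (tid == 0)
    {
        vote = 1u; // init: assume the whole tile can be skipped
    }
    __syncthreads(); // stand-in for the named barrier that orders init before voting

    // Per-thread decision for its row: skip if the local max did not grow enough.
    bool skip = expf(localMax[tid] - globalMax[tid]) < threshold;
    // AND the decisions within the warp.
    skip = __all_sync(0xffffffffu, skip);
    if (tid % 32 == 0)
    {
        // The leader of each warp votes.
        atomicAnd(&vote, static_cast<unsigned>(skip));
    }
    __syncthreads(); // stand-in for the second named barrier before reading the vote
    if (tid == 0)
    {
        *outSkip = static_cast<int>(vote);
    }
}

int main()
{
    constexpr int n = 128;
    float hLocal[n], hGlobal[n];
    for (int i = 0; i < n; ++i)
    {
        hGlobal[i] = 10.f;
        hLocal[i] = 2.f; // far below the running max -> every row votes "skip"
    }
    float *dLocal, *dGlobal;
    int* dSkip;
    cudaMalloc(&dLocal, n * sizeof(float));
    cudaMalloc(&dGlobal, n * sizeof(float));
    cudaMalloc(&dSkip, sizeof(int));
    cudaMemcpy(dLocal, hLocal, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dGlobal, hGlobal, n * sizeof(float), cudaMemcpyHostToDevice);
    voteKernel<<<1, n>>>(dLocal, dGlobal, 1e-3f, dSkip);
    int skip = 0;
    cudaMemcpy(&skip, dSkip, sizeof(int), cudaMemcpyDeviceToHost);
    printf("skip = %d\n", skip); // expected: 1
    cudaFree(dLocal);
    cudaFree(dGlobal);
    cudaFree(dSkip);
    return 0;
}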

View File

@ -71,6 +71,8 @@ template <
bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
// Save softmax stats ?
bool RETURN_SOFTMAX_STATS_ = false,
// Enable skip softmax attention feature
bool ENABLE_SKIP_SOFTMAX_ = false,
// The output type (only used by fp8 kernels).
typename OutputType = typename Instruction_traits<STEP_Q_, STEP_KV_, 0, false, false>::A_type,
// The sage attention block size for Q, K and V
@ -290,6 +292,12 @@ struct Kernel_traits
USE_CUSTOM_MASK = ATTENTION_MASK_TYPE_ == 3
};
// Are we enabling skip softmax attention feature?
enum
{
ENABLE_SKIP_SOFTMAX = ENABLE_SKIP_SOFTMAX_
};
static_assert(!USE_CUSTOM_MASK || STEP_KV == 64 || STEP_KV == 128 || STEP_KV == 256, "Not implemented!");
// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
@ -384,6 +392,8 @@ struct Kernel_traits
// Named barrier ids
static constexpr int DMA_SYNC_BARRIER_ID = 0x1;
static constexpr int MMA_SYNC_BARRIER_ID = 0x2;
// There are 2 warpgroups so 0x3 and 0x4 are used for skip-softmax
static constexpr int SKIP_SOFTMAX_BARRIER_ID = 0x3;
// How many threads get involved in the dma group.
enum
@ -518,6 +528,10 @@ struct Kernel_traits
// Mutex
OrderedMutex compute_mutex;
// 4 warps in a warpgroup vote to an atomic variable in shared memory
// to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive KV steps.
uint32_t skip_softmax_votes[2][NUM_COMPUTE_GROUPS];
inline __device__ void init(int tid0)
{
@ -580,6 +594,8 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
// Save softmax stats ?
bool RETURN_SOFTMAX_STATS_ = false,
// Enable skip softmax attention feature
bool ENABLE_SKIP_SOFTMAX_ = false,
// The output type (only used by fp8 kernels).
typename OutputType = e4m3_t,
// The sage attention block size for Q, K and V
@ -588,14 +604,15 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
: public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_,
ENABLE_MUTEX_, SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_,
RETURN_SOFTMAX_STATS_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>
RETURN_SOFTMAX_STATS_, ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_,
SAGE_BLOCK_SIZE_V_>
{
// Base class.
using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_, RETURN_SOFTMAX_STATS_,
OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
enum
{
@ -693,6 +710,10 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
// Mutex
OrderedMutex compute_mutex;
// 4 warps in a warpgroup vote to an atomic variable in shared memory
// to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive KV steps.
uint32_t skip_softmax_votes[2][Base::NUM_COMPUTE_GROUPS];
inline __device__ void init(int tid0)
{

View File

@ -276,7 +276,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
// scale factors
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
// flags
bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi)
bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi,
float const skip_softmax_threshold_scale_factor)
{
memset(&params, 0, sizeof(params));
@ -421,6 +422,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
params.enable_i2f_trick
= -double(1 << 22) * double(scale_bmm2) <= -128.f && double(1 << 22) * double(scale_bmm2) >= 127.f;
}
// Skip-softmax attention
params.skip_softmax_threshold_scale_factor = skip_softmax_threshold_scale_factor;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -429,7 +433,7 @@ static inline void determine_launch_params(Launch_params& launch_params, Data_ty
const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout,
bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma,
bool const force_non_flash_attention, bool const force_non_warp_specialization,
bool const force_non_granular_tiling, bool const force_fp32_acc,
bool const force_non_granular_tiling, bool const force_fp32_acc, float const skip_softmax_threshold_scale_factor,
// device props
const cudaDeviceProp props)
{
@ -470,6 +474,9 @@ static inline void determine_launch_params(Launch_params& launch_params, Data_ty
"are not supported on Ada currently.\n");
launch_params.use_granular_tiling = false;
}
// Enable skip softmax attention or not.
launch_params.enable_skip_softmax = skip_softmax_threshold_scale_factor > 0.f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -589,6 +596,9 @@ int main(int argc, char** argv)
// Use attention sinks (added to the denominator of softmax)
bool use_attention_sinks = false;
// Skip-softmax attention
float skip_softmax_threshold_scale_factor = 0;
// Read the parameters from the command-line.
for (int ii = 1; ii < argc; ++ii)
{
@ -885,6 +895,10 @@ int main(int argc, char** argv)
{
use_attention_sinks = true;
}
else if (!strcmp(argv[ii], "-skip-softmax-threshold-scale-factor") && ++ii < argc)
{
skip_softmax_threshold_scale_factor = strtof(argv[ii], nullptr);
}
else
{
fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]);
@ -1057,7 +1071,7 @@ int main(int argc, char** argv)
Launch_params launch_params;
determine_launch_params(launch_params, data_type, sm, s, d, attention_mask_type, input_layout, interleaved,
ignore_b1opt, force_unroll, use_tma, force_non_flash_attention, force_non_warp_specialization,
force_non_granular_tiling, force_fp32_acc, props);
force_non_granular_tiling, force_fp32_acc, skip_softmax_threshold_scale_factor, props);
// The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D.
const size_t qkv_size = s * b * h * (2 * d + dv);
@ -1713,7 +1727,13 @@ int main(int argc, char** argv)
tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d,
packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d,
softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
use_int8_scale_max, interleaved, is_s_padded, has_alibi);
use_int8_scale_max, interleaved, is_s_padded, has_alibi, skip_softmax_threshold_scale_factor);
#ifdef SKIP_SOFTMAX_STAT
FMHA_CHECK_CUDA(cudaMalloc(&params_v2.skip_softmax_total_blocks, sizeof(uint32_t)));
FMHA_CHECK_CUDA(cudaMalloc(&params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t)));
FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_total_blocks, 0, sizeof(uint32_t)));
FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_skipped_blocks, 0, sizeof(uint32_t)));
#endif
// total number of tokens is needed to set TMA desc on the host.
launch_params.total_q_seqlen = q_seqlens[b];
@ -2101,6 +2121,18 @@ int main(int argc, char** argv)
non_fused_elapsed / fused_elapsed, total_flops / (fused_elapsed / float(runs) / 1e-9),
total_bytes / (fused_elapsed / float(runs) / 1e-6));
}
#ifdef SKIP_SOFTMAX_STAT
if (skip_softmax_threshold_scale_factor > 0)
{
uint32_t total_blocks, skipped_blocks;
FMHA_CHECK_CUDA(
cudaMemcpy(&total_blocks, params_v2.skip_softmax_total_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
FMHA_CHECK_CUDA(cudaMemcpy(
&skipped_blocks, params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
printf("Skip-Softmax .: %u / %u = %.2f%%\n", skipped_blocks, total_blocks,
total_blocks ? 100.f * skipped_blocks / total_blocks : 0.f);
}
#endif
#if defined(DEBUG_HAS_PRINT_BUFFER)
FMHA_CHECK_CUDA(cuda_memcpy_d2h(print_buffer.data(), params.print_ptr, print_buffer.size(), DATA_TYPE_FP32));
@ -2141,6 +2173,11 @@ int main(int argc, char** argv)
FMHA_CHECK_CUDA(cudaFree(kv_cache_block_offsets_d));
FMHA_CHECK_CUDA(cudaFree(contiguous_kv_d));
FMHA_CHECK_CUDA(cudaFree(softmax_stats_d));
FMHA_CHECK_CUDA(cudaFree(attention_sinks_d));
#ifdef SKIP_SOFTMAX_STAT
FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_total_blocks));
FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_skipped_blocks));
#endif
free(qkv_h);
free(mask_h);

View File

@ -283,6 +283,16 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
float* scales;
} q, k, v;
} sage;
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
// A positive value means skip-softmax is enabled.
float skip_softmax_threshold_scale_factor = 0;
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax; pointers to device memory for output
uint32_t* skip_softmax_total_blocks;
uint32_t* skip_softmax_skipped_blocks;
#endif
};
#endif
@ -322,6 +332,8 @@ struct Fused_multihead_attention_launch_params
// hardware properties to determine how to launch blocks
int multi_processor_count = 0;
int device_l2_cache_size = 0;
// skip softmax attention
bool enable_skip_softmax = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
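A minimal host-side sketch, condensed from set_params and determine_launch_params earlier in this change, of how the two new knobs relate: a positive scale factor enables the feature, and the kernel later divides it by the sequence length to obtain the per-call threshold. The struct and helper names below are illustrative stand-ins, not the real types.

#include <cstdio>

// Illustrative stand-ins for the fields added in this change.
struct ParamsSketch
{
    float skip_softmax_threshold_scale_factor = 0.f;
};

struct LaunchParamsSketch
{
    bool enable_skip_softmax = false;
};

// Mirrors the host-side logic above: a positive scale factor enables the feature.
void configureSkipSoftmax(ParamsSketch& params, LaunchParamsSketch& launch, float scaleFactor)
{
    params.skip_softmax_threshold_scale_factor = scaleFactor;
    launch.enable_skip_softmax = scaleFactor > 0.f;
}

int main()
{
    ParamsSketch params;
    LaunchParamsSketch launch;
    configureSkipSoftmax(params, launch, 0.5f); // hypothetical scale factor
    // For a 4096-token sequence the per-call threshold would be 0.5 / 4096.
    printf("enabled=%d threshold=%g\n", launch.enable_skip_softmax,
        params.skip_softmax_threshold_scale_factor / 4096.f);
    return 0;
}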

View File

@ -177,4 +177,13 @@ struct Fused_multihead_attention_params_v2
float* scales;
} q, k, v;
} sage;
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
// A positive value means skip-softmax is enabled.
float skip_softmax_threshold_scale_factor = 0;
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax; pointers to device memory for output
uint32_t* skip_softmax_total_blocks;
uint32_t* skip_softmax_skipped_blocks;
#endif
};

View File

@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
#define SLIDING_WINDOW 0
#endif
#ifndef SKIP_SOFTMAX_ATTN
#define SKIP_SOFTMAX_ATTN 0
#endif
#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
#endif
#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
#endif
// 0 - no PDL
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)

View File

@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
asm volatile("trap;\n");
return 0;
}();
assert(__cvta_generic_to_shared(data) % baseAlign == 0);
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
return MatDesc{
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),

View File

@ -2734,6 +2734,25 @@ static constexpr auto kernel_mha = kernel_mha_impl;
#endif
#ifndef GENERATE_CUBIN
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}
void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@ -2771,6 +2790,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
@ -2793,24 +2819,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
uint32_t const nbQHeads = nbKHeads * headGrpSize;
// const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1;
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}();
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
// gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq
#if SPEC_DEC
const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock);
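The refactor above only lifts the existing sub-sequence heuristic into computeNbSubSeqPerSeqMHA so that the test can call it; the formula is unchanged. A small sketch with hypothetical numbers shows the kind of value it produces (the XQA_NB_SUB_SEQ override is omitted, and ctaTile.x, a compile-time constant in the kernel, is passed in explicitly):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Same formula as computeNbSubSeqPerSeqMHA above, with the CTA tile width passed in.
uint32_t nbSubSeqPerSeqSketch(uint32_t smCount, uint32_t batchSize, uint32_t nbKHeads,
    uint32_t maxSeqLen, uint32_t ctaTileX)
{
    uint32_t const nbTiles = (maxSeqLen + ctaTileX - 1) / ctaTileX; // divUp(maxSeqLen, ctaTile.x)
    return std::min<uint32_t>(std::max<uint32_t>(1U, smCount / (batchSize * nbKHeads)), nbTiles);
}

int main()
{
    // Hypothetical values: 132 SMs, batch 4, 8 K heads, 8192 tokens, 64-token tiles.
    printf("%u\n", nbSubSeqPerSeqSketch(132, 4, 8, 8192, 64)); // min(max(1, 4), 128) = 4
    return 0;
}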

View File

@ -90,6 +90,9 @@ struct BeamSearchParams
// match trt-llm API.
};
uint32_t computeNbSubSeqPerSeqMHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
uint32_t computeNbSubSeqPerSeqHopperF8MHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);

View File

@ -49,6 +49,10 @@ static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is to
#define SWAP_AB (!SPEC_DEC)
#endif
#if SKIP_SOFTMAX_ATTN
static_assert(SWAP_AB && USE_PAGED_KV_CACHE && !SPEC_DEC && BEAM_WIDTH == 1, "SKIP_SOFTMAX_ATTN is not supported with this configuration.");
#endif
#define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT)
inline constexpr bool swapAB = SWAP_AB;
@ -138,26 +142,38 @@ using PaddedOutHead = PaddedInputHead;
struct alignas(128) SharedMem
{
using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
using KBuffer = Array2D<LdGrain, gemm0CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes)>;
static constexpr uint32_t nbKBuf = 2;
KBuffer k[nbKBuf]; // as is loaded from global mem.
using XBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerXPart>, nbXParts>;
static constexpr uint32_t nbXBuf
= 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
using VBuffer = Vec<Array2D<LdGrain, gemm1CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes),
sizeof(XBuffer) % (cacheHeadPartBytes * 8) == 0>,
cacheHeadNbParts>;
#if !SWAP_AB
using VTBuffer = Array2D<LdGrain, headElems, exactDiv(gemm1CtaTileNbTokens, cacheElemsPerGrain), true>;
#endif
static constexpr uint32_t nbVBuf = 2;
#if CACHE_ELEM_ENUM == 0
using OutSwizzleBuf = Array2D<LdGrain, ctaNbQHeads, grainsPerPaddedInputHead>;
#elif CACHE_ELEM_ENUM == 2
using OutSwizzleBuf = Array2D<Vec<Vec<InputElem, 4>, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
#endif
#if SKIP_SOFTMAX_ATTN
static constexpr uint32_t nbKBuf = 2;
static constexpr uint32_t nbVBuf = 3; // @fixme: skip_softmax_attn: for skip softmax attn, an extra VBuffer is used
static constexpr uint32_t nbXBuf
= 3 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
#else
static constexpr uint32_t nbKBuf = 2;
static constexpr uint32_t nbVBuf = 2;
static constexpr uint32_t nbXBuf
= 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
#endif
static_assert(nbXBuf == nbVBuf);
// note: buffers used for GMMA may have additional alignment requirements
KBuffer k[nbKBuf]; // as is loaded from global mem.
QBuffer q; // For gmma math. Conversion done if needed.
union ReusedXVOutSwizzleBuf
{
struct XV
@ -196,9 +212,6 @@ struct alignas(128) SharedMem
return reusedXVOutSwizzleBuf[i].outSwizzle;
}
using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
QBuffer q; // For gmma math. Conversion done if needed.
// @fixme: move these into reusedXVOutSwizzleBuf
#if SWAP_AB
ShmQWiseVec xColMax[nbXBuf];
@ -220,6 +233,11 @@ struct alignas(128) SharedMem
Vec<KVCachePageIndex, nbPagesPerTile> pages[2]; // one for K and one for V
#endif
#if SKIP_SOFTMAX_ATTN
uint32_t skipSoftmaxVotesGemm0ToV[nbXBuf]; // guarded by skipSoftmaxXBar
uint32_t skipSoftmaxVotesGemm0ToGemm1[nbXBuf]; // guarded by xBar
#endif
// mem barriers
CtaBarrierPair qBar;
@ -229,6 +247,9 @@ struct alignas(128) SharedMem
CtaBarrierPair vtBar[nbVBuf];
#endif
CtaBarrierPair xBar[nbXBuf];
#if SKIP_SOFTMAX_ATTN
CtaBarrierPair skipSoftmaxXBar[nbXBuf]; // for V to wait for X to be ready
#endif
// used internally in the gemm0 warp group
// @fixme: use separate arrive and wait for all usage
@ -425,8 +446,13 @@ __device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
#endif
#if SWAP_AB
#if SKIP_SOFTMAX_ATTN
__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src,
float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip);
#else
__device__ RegColWiseVec computeWarpGrpColMax_sync(
CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
#endif
__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd);
__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax);
__device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
@ -675,6 +701,12 @@ CUBIN_EXPORT __global__
#endif
#if SPEC_DEC
SpecDecParams const specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* __restrict__ const semaphores
= nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
@ -753,6 +785,10 @@ CUBIN_EXPORT __global__
uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1;
static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
assert(isMultiBlockMode == (nbSubSeq > 1));
#if SKIP_SOFTMAX_ATTN
bool const disableSkipForShortSeq = (cacheSeqLen < skipSoftmaxThresholdScaleFactor);
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / cacheSeqLen;
#endif
if (idxSubSeq >= nbSubSeq)
{
return;
@ -776,21 +812,34 @@ CUBIN_EXPORT __global__
assert(dynamicSmemSize() >= sizeof(SharedMem));
SharedMem& smem = *reinterpret_cast<SharedMem*>(&smemByteBuf[0]);
constexpr uint32_t nbBuffers = 2;
static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf);
if (wid < nbBuffers)
constexpr uint32_t maxNbBuffers = (SharedMem::nbXBuf > SharedMem::nbVBuf) ? SharedMem::nbXBuf : SharedMem::nbVBuf;
static_assert(
maxNbBuffers >= SharedMem::nbKBuf && maxNbBuffers >= SharedMem::nbVBuf && maxNbBuffers >= SharedMem::nbXBuf);
if (wid < maxNbBuffers)
{
if (warpElectSync())
{
smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
#if !SWAP_AB
smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
if (wid < SharedMem::nbKBuf)
{
smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
}
if (wid < SharedMem::nbXBuf)
{
#if SKIP_SOFTMAX_ATTN
smem.skipSoftmaxXBar[wid].initialize(gemm0NbThrds + warp_size, gemm0NbThrds + warp_size);
smem.vBar[wid].initialize(gemm1NbThrds + warp_size, gemm1NbThrds + warp_size);
#else
smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
#endif
smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
#if !SWAP_AB
smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
#endif
smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
}
}
}
else if (wid == nbBuffers)
else if (wid == maxNbBuffers)
{
if (warpElectSync())
{
@ -819,6 +868,10 @@ CUBIN_EXPORT __global__
SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen};
#endif
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t localSkippedBlockCount = 0;
#endif
// QK gemm
constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM);
using Acc = GmmaAcc<gemm0CtaTileNbTokens, ctaNbQHeads>;
@ -940,10 +993,39 @@ CUBIN_EXPORT __global__
}
}
#endif
uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
auto& xBar = smem.xBar[idxXBuf];
// update colMax in shared mem and get a register copy
#if SWAP_AB
#if SKIP_SOFTMAX_ATTN
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
skipSoftmaxXBar.consumed.arrive_and_wait();
bool const maybeSkip = !disableSkipForShortSeq && idxIter != 0;
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc,
skipSoftmaxThreshold, &smem.skipSoftmaxVotesGemm0ToV[idxXBuf], maybeSkip);
bool const shouldSkipSoftmaxAttn = static_cast<bool>(smem.skipSoftmaxVotesGemm0ToV[idxXBuf]);
unused(skipSoftmaxXBar.produced.arrive());
warpGrpOnlineSoftmax(acc, colMax);
if (shouldSkipSoftmaxAttn)
{
xBar.consumed.arrive_and_wait();
if (threadIdx.x == 0)
{
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 1U;
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
localSkippedBlockCount++;
#endif
}
asm volatile("fence.proxy.async.shared::cta;\n"); // maybe not used
unused(xBar.produced.arrive());
continue;
}
#else
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc);
warpGrpOnlineSoftmax(acc, colMax);
#endif
#else
RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc);
warpGrpOnlineSoftmax(acc, rowMax);
@ -959,8 +1041,6 @@ CUBIN_EXPORT __global__
// map 1 to fp8_max before conversion to fp8
acc = acc * kE4M3_MAX;
uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
auto& xBar = smem.xBar[idxXBuf];
// @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM.
#if SWAP_AB
storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc);
@ -989,13 +1069,25 @@ CUBIN_EXPORT __global__
storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax);
storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum);
#endif
#if SKIP_SOFTMAX_ATTN
if (threadIdx.x == 0)
{
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 0;
}
#endif
__syncwarp();
// the release semantics of arrive does not work for async consumers like gmma. additional fence is
// needed.
asm volatile("fence.proxy.async.shared::cta;\n");
unused(xBar.produced.arrive());
}
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
if (threadIdx.x == 0 && skippedBlockCount != nullptr && totalBlockCount != nullptr)
{
atomicAdd(skippedBlockCount, localSkippedBlockCount);
atomicAdd(totalBlockCount, nbIters);
}
#endif
unused(smem.qBar.consumed.arrive());
}
else if (warpIdx.z == 1)
@ -1043,216 +1135,231 @@ CUBIN_EXPORT __global__
uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq;
auto const idxVBuf = idxIter % SharedMem::nbVBuf;
auto const idxXBuf = idxVBuf;
auto& vBar = smem.vBar[idxVBuf];
arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
auto const& vBuf = smem.vBuf(idxVBuf);
#if !SWAP_AB
CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
auto& vtBuf = smem.vtBuf(idxVBuf);
vtBar.consumed.arrive_and_wait();
transposeVTile(warpRank, laneId(), vtBuf, vBuf);
vBar.consumed.arrive();
vtBar.produced.arrive();
#endif
auto& xBar = smem.xBar[idxXBuf];
auto& vBar = smem.vBar[idxVBuf];
auto const& vBuf = smem.vBuf(idxVBuf);
xBar.produced.arrive_and_wait();
#if SKIP_SOFTMAX_ATTN
bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf]; // guarded by xBar
if (shouldSkipSoftmaxAttn)
{
vBar.produced.arrive_and_wait();
}
#endif
#if SKIP_SOFTMAX_ATTN
if (!shouldSkipSoftmaxAttn) // skip XVGemm
#endif
{
arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
#if !SWAP_AB
CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
auto& vtBuf = smem.vtBuf(idxVBuf);
vtBar.consumed.arrive_and_wait();
transposeVTile(warpRank, laneId(), vtBuf, vBuf);
vBar.consumed.arrive();
vtBar.produced.arrive();
#endif
#if !defined(NDEBUG) && DBG_PRINT
#if SWAP_AB
if (threadIdx.x == 0)
{
printf("colMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xColMax[idxXBuf][i]);
}
printf("\n");
printf("colSum:\n");
for (int n = 0; n < 4; n++)
if (threadIdx.x == 0)
{
printf("colMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xColSum[idxXBuf][n][i]);
printf("%f, ", smem.xColMax[idxXBuf][i]);
}
printf("\n");
printf("colSum:\n");
for (int n = 0; n < 4; n++)
{
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xColSum[idxXBuf][n][i]);
}
printf("\n");
}
printf("\n");
printf("X:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
for (int j = 0; j < gemm0CtaTileNbTokens; j++)
{
auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
printf("%.2f, ", float(e));
if (j % 16 == 15)
{
printf("| ");
}
}
printf("\n\n");
}
}
smem.gemm1WarpGrpBar.arrive_and_wait();
#else
if (blockIdx.y == 1 && threadIdx.x == 0)
{
printf("rowMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowMax[idxXBuf][i]);
}
printf("\n");
printf("rowSum:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowSum[idxXBuf][i]);
}
printf("\n");
}
printf("\n");
printf("X:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
for (int j = 0; j < gemm0CtaTileNbTokens; j++)
{
auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
printf("%.2f, ", float(e));
if (j % 16 == 15)
{
printf("| ");
}
}
printf("\n\n");
}
}
smem.gemm1WarpGrpBar.arrive_and_wait();
#else
if (blockIdx.y == 1 && threadIdx.x == 0)
{
printf("rowMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowMax[idxXBuf][i]);
}
printf("\n");
printf("rowSum:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowSum[idxXBuf][i]);
}
printf("\n");
}
smem.gemm1WarpGrpBar.arrive_and_wait();
smem.gemm1WarpGrpBar.arrive_and_wait();
#endif
#endif
#if SWAP_AB
// @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead.
rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf],
smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar);
// @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead.
rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf],
smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar);
#else
rescaleGemm1AccForNewRowMax_sync(
warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf],
smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
#endif
auto& xBuf = smem.xBuf(idxXBuf);
auto& xBuf = smem.xBuf(idxXBuf);
auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::XBuffer::Elem{}))
.raw();
auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::XBuffer::Elem{}))
.raw();
#if CACHE_ELEM_ENUM == 0
auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::VBuffer::Elem{}))
.raw();
auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::VBuffer::Elem{}))
.raw();
#endif
#if SWAP_AB
//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed.
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++)
{
for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++)
{
#if CACHE_ELEM_ENUM == 2
Vec<RegMatAFrag, gemm1NbGmmaInstM> const fragA
= loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK);
Vec<RegMatAFrag, gemm1NbGmmaInstM> const fragA
= loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK);
#if !defined(NDEBUG) && DBG_PRINT
if (threadIdx.x == 0)
{
printf("fragA:\nidxInstK == %u\n", idxInstK);
}
smem.gemm1WarpGrpBar.arrive_and_wait();
for (int m = 0; m < 2; m++)
{
for (int w = 0; w < 4; w++)
if (threadIdx.x == 0)
{
if (warpRank == w)
printf("fragA:\nidxInstK == %u\n", idxInstK);
}
smem.gemm1WarpGrpBar.arrive_and_wait();
for (int m = 0; m < 2; m++)
{
for (int w = 0; w < 4; w++)
{
if (laneId() == 0)
if (warpRank == w)
{
printf(" warpRank = %u\n", warpRank);
}
__syncwarp();
for (int a = 0; a < 2; a++)
{
for (int b = 0; b < 8; b++)
if (laneId() == 0)
{
for (int c = 0; c < 2; c++)
printf(" warpRank = %u\n", warpRank);
}
__syncwarp();
for (int a = 0; a < 2; a++)
{
for (int b = 0; b < 8; b++)
{
for (int d = 0; d < 4; d++)
for (int c = 0; c < 2; c++)
{
if (laneId() == b * 4 + d)
for (int d = 0; d < 4; d++)
{
for (int e = 0; e < 4; e++)
if (laneId() == b * 4 + d)
{
auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
fragA[m](0, c)(a, 0));
printf("%.2f, ", float(elem4[e]));
for (int e = 0; e < 4; e++)
{
auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
fragA[m](0, c)(a, 0));
printf("%.2f, ", float(elem4[e]));
}
}
__syncwarp();
}
__syncwarp();
}
if (laneId() == 0)
{
printf("\n");
}
__syncwarp();
}
if (laneId() == 0)
if (laneId() == 0 && a == 0)
{
printf("\n");
printf("----------------------\n");
}
__syncwarp();
}
if (laneId() == 0 && a == 0)
{
printf("----------------------\n");
}
__syncwarp();
}
smem.gemm1WarpGrpBar.arrive_and_wait();
}
smem.gemm1WarpGrpBar.arrive_and_wait();
}
}
#endif
#endif
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * idxInstK};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
0, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * idxInstK};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
0, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
#if CACHE_ELEM_ENUM == 2
gmma::fence();
gmma::fence();
#endif
#pragma unroll
for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++)
{
for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++)
{
#if CACHE_ELEM_ENUM == 0
auto const descV
= addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0));
gmma::mma_async_shmA<MathElem, ctaNbQHeads, true, false>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
descV, descX, true);
auto const descV
= addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0));
gmma::mma_async_shmA<MathElem, ctaNbQHeads, true, false>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
descV, descX, true);
#elif CACHE_ELEM_ENUM == 2
gmma::mma_async_regA<MathElem, ctaNbQHeads>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
reinterpret_cast<uint32_t const(&)[2][2][1]>(fragA[idxInstM]), descX, true);
gmma::mma_async_regA<MathElem, ctaNbQHeads>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
reinterpret_cast<uint32_t const(&)[2][2][1]>(fragA[idxInstM]), descX, true);
#endif
}
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
// gmma.
gmma::wait_group<0>();
}
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
// gmma.
gmma::wait_group<0>();
}
#else
auto const descVTBase = gmma::makeMatDesc(
nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode<true>(SharedMem::VTBuffer{}))
.raw();
vtBar.produced.arrive_and_wait();
auto const descVTBase = gmma::makeMatDesc(
nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode<true>(SharedMem::VTBuffer{}))
.raw();
vtBar.produced.arrive_and_wait();
// if (idxIter == 1 && threadIdx.x == 0) {
// printf("vtBuf:\n");
// dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf);
// }
#pragma unroll
for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
{
#pragma unroll
for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
{
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
auto const descVT = addAddr(
descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
gmma::mma_async_shmA<MathElem, headElems>(
reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
descVT, true);
#pragma unroll
for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
{
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
auto const descVT = addAddr(
descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
gmma::mma_async_shmA<MathElem, headElems>(
reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
descVT, true);
}
}
}
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of gmma.
gmma::wait_group<0>();
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
// gmma.
gmma::wait_group<0>();
#endif
}
if (idxIter == nbIters - 1)
{
// gmma::wait_group should have already synchronized threads, so this may be unnecessary.
@ -1471,8 +1578,24 @@ CUBIN_EXPORT __global__
tensorMap
#endif
};
#if SKIP_SOFTMAX_ATTN
for (auto& b : smem.skipSoftmaxXBar)
{
unused(b.consumed.arrive());
}
#endif
for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++)
{
uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
auto& vBar = smem.vBar[idxVBuf];
#if SKIP_SOFTMAX_ATTN
uint32_t idxXBuf = idxIter % SharedMem::nbXBuf;
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
skipSoftmaxXBar.produced.arrive_and_wait();
bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToV[idxXBuf];
skipSoftmaxXBar.consumed.arrive();
#endif
uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq;
vTileLoader.loadPages(idxVTile);
#if USE_INPUT_KV || ENABLE_PDL == 2
@ -1506,8 +1629,20 @@ CUBIN_EXPORT __global__
}
#endif
uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
auto& vBar = smem.vBar[idxVBuf];
#if SKIP_SOFTMAX_ATTN
if (shouldSkipSoftmaxAttn)
{
vBar.consumed.arrive_and_wait();
// Compared to non-skip softmax attn, we need to increase the vBar.produced count to avoid a race
// condition where vBar.consumed is arrived again without a wait. Without skip softmax attn, XVGemm
// waits for tx_count, so its progress cannot go ahead of the vload warp. With skip softmax attn, the
// XVGemm WG may go ahead of the vload warp, as the previous vBar only has the XVGemm WG threads and a
// tx_count (now = 0), so it may arrive vBar.consumed before it is arrive_and_wait-ed.
vBar.produced.arrive();
continue;
}
#endif
vBar.consumed.arrive_and_wait();
if (warpElectSync())
{
@ -1517,6 +1652,9 @@ CUBIN_EXPORT __global__
vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced);
}
}
#if SKIP_SOFTMAX_ATTN
vBar.produced.arrive();
#endif
__syncwarp();
}
}
@ -1992,9 +2130,23 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
#endif // SPEC_DEC
// smemColMax is persistent across multiple iterations
#if SKIP_SOFTMAX_ATTN
__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax,
Gemm0Acc const& src, float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip)
#else
__device__ inline RegColWiseVec computeWarpGrpColMax_sync(
CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src)
#endif
{
#if SKIP_SOFTMAX_ATTN
if (threadIdx.x == 0)
{
*smemSkipVote = maybeSkip ? 1U : 0U; // will sync before vote
}
float const lnThreshold
= log(skipSoftmaxThreshold); // this can be -inf, but should be safe as we only use it for comparison
#endif
auto colMax = RegColWiseVec::filled(Vec<float, 2>::filled(safeInitRowMax));
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
@ -2029,6 +2181,9 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
}
uint32_t const lane = laneId();
#if SKIP_SOFTMAX_ATTN
auto prevOrCurrentMax = RegColWiseVec();
#if SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
if (lane < 4)
{
#pragma unroll
@ -2037,12 +2192,43 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
prevOrCurrentMax[n][j] = smemColMax[8 * n + 2 * lane + j];
}
}
}
warpGrpBar.arrive_and_wait();
#endif
#endif
if (lane < 4)
{
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
#if SKIP_SOFTMAX_ATTN && !SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
// prevOrCurrentMax <= the actual smemColMax (after updates from all 4 warps are done), but always >=
// smemColMax(Prev), the smemColMax value *before* this tile is computed.
// When determining whether to skip, it is safe to use prevOrCurrentMax: 1) if all 4 warps' localMax <
// smemColMax(Prev), then prevOrCurrentMax == smemColMax(Prev) and the result is not affected; 2) if some
// localMax > smemColMax(Prev), then prevOrCurrentMax may exceed smemColMax(Prev) and some warps may
// incorrectly vote skip, but at least one warp whose localColMax is larger will not skip, so the tile is
// not skipped.
// This reduces some syncs and checks, but has issues when the threshold is > 1.
prevOrCurrentMax[n][j] = atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
#else
atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
#endif
}
}
}
warpGrpBar.arrive_and_wait();
uint32_t const idxInQuad = lane % 4;
#if SKIP_SOFTMAX_ATTN
bool localShouldSkip = true;
#endif
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
@ -2050,10 +2236,21 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
{
#if SKIP_SOFTMAX_ATTN
if (lane < 4 && 8 * n + 2 * idxInQuad + j < headGrpSize)
{
localShouldSkip &= (colMax[n][j] - prevOrCurrentMax[n][j]) < lnThreshold;
}
#endif
assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]);
colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j];
}
}
#if SKIP_SOFTMAX_ATTN
atomicAnd(smemSkipVote, static_cast<uint32_t>(localShouldSkip)); // this will be translated to redux and voteu
#endif
warpGrpBar.arrive_and_wait();
return colMax;
}
@ -2199,7 +2396,7 @@ __device__ inline void storeGemm0AccToShm(
uint32_t const idxOctInsideHalf = idxInHalf / 8;
uint32_t const idxRowInsideOct = lane % 8;
uint32_t const warpBaseC = 16 * warpRank;
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair<uint32_t, uint32_t>
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> mha::pair<uint32_t, uint32_t>
{
uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols;
uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols;
@ -3231,6 +3428,24 @@ __device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst,
}
#ifndef GENERATE_CUBIN
uint32_t computeNbSubSeqPerSeqHopperF8MHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
float const factor = 0.25f;
return mha::min<uint32_t>(
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
divUp(maxSeqLen, gemm0CtaTileNbTokens));
}
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@ -3268,6 +3483,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
@ -3286,22 +3507,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
uint32_t const nbVHeads = nbKHeads;
uint32_t const nbQHeads = nbKHeads * headGrpSize;
uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
float const factor = 0.25f;
return mha::min<uint32_t>(
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
divUp(maxSeqLen, gemm0CtaTileNbTokens));
}();
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqHopperF8MHA(prop, batchSize, nbKHeads, maxSeqLen);
#if SPEC_DEC
uint32_t const qSeqLen = specDecParams.qSeqLen;
#else
@ -3371,6 +3577,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#endif
#if SPEC_DEC
specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
skippedBlockCount, totalBlockCount,
#endif
#endif
semaphores, scratch);
#else

View File

@ -1272,6 +1272,19 @@ using is_void = is_same<remove_cv_t<T>, void>;
template <typename T>
inline constexpr bool is_void_v = is_void<T>::value;
#endif
#ifndef GENERATE_CUBIN
template <typename T1, typename T2>
using pair = std::pair<T1, T2>;
#else
template <typename T1, typename T2>
struct pair
{
T1 first;
T2 second;
};
#endif
} // namespace mha
#if GENERATE_CUBIN

View File

@ -50,7 +50,8 @@ using Vector = Matrix<Type, Size, 1>;
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
{
uint32_t const nbTiles = divUp(seqLen, tileSize);
auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();
@ -61,6 +62,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
float const qkScale = qScale * kvScale / sqrtf(validElemsPerHead);
uint32_t const seqBeg = (seqLen < slidingWinSize ? 0 : seqLen - slidingWinSize);
uint32_t const idxTileBeg = seqBeg / tileSize;
uint32_t const nbSubSeq = (multiBlockNum > 0 && nbTiles >= 2) ? mha::min(nbTiles, multiBlockNum) : 1;
std::vector<Eigen::Vector<float, headGrpSize>> skipRowMaxs(nbSubSeq);
for (uint32_t i = 0; i < nbSubSeq; i++)
{
skipRowMaxs[i].fill(-INFINITY);
}
bool const disableSkipForShortSeq = (seqLen < skipSoftmaxThresholdScaleFactor);
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / seqLen;
for (uint32_t idxTile = idxTileBeg; idxTile < nbTiles; idxTile++)
{
Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> gemm0Acc;
@ -88,7 +99,22 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
}
}
Eigen::Vector<float, headGrpSize> const tileRowMax = gemm0Acc.rowwise().maxCoeff().cwiseMax(rowMax).eval();
Eigen::Vector<float, headGrpSize> const localRowMax = gemm0Acc.rowwise().maxCoeff().eval();
Eigen::Vector<float, headGrpSize> const tileRowMax = localRowMax.cwiseMax(rowMax).eval();
auto const prevSkipRowMax = skipRowMaxs[idxTile % nbSubSeq];
skipRowMaxs[idxTile % nbSubSeq] = localRowMax.cwiseMax(skipRowMaxs[idxTile % nbSubSeq]).eval();
if (!disableSkipForShortSeq && skipSoftmaxThreshold > 0)
{
*totalBlockCount += 1;
auto const skipSoftmaxMask = ((localRowMax - prevSkipRowMax).array() < std::log(skipSoftmaxThreshold));
bool const skipBlock = skipSoftmaxMask.all() && ((idxTile - idxTileBeg) >= nbSubSeq);
if (skipBlock)
{
*skippedBlockCount += 1;
continue;
}
}
Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> tileX
= (gemm0Acc.colwise() - tileRowMax).array().exp().eval();
@ -138,7 +164,8 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, \
float skipSoftmaxThreshold, uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);
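To summarize the reference-check semantics added in this hunk: the reference keeps one running row-max per simulated sub-sequence (indexed by idxTile % nbSubSeq) and only starts skipping once every sub-sequence has seen at least one tile. With c = skipSoftmaxThresholdScaleFactor, L = seqLen, m_t(r) the row max of tile t, and M_prev^{(i)}(r) the running max of sub-sequence i before tile t:

\[
\text{skip tile } t \iff \left(t - t_{\text{beg}} \ge n_{\text{sub}}\right)
\;\wedge\; \forall r:\ m_t(r) - M^{(t \bmod n_{\text{sub}})}_{\text{prev}}(r) < \ln\frac{c}{L},
\]

with skipping disabled entirely when L < c. The skipped and total tile counters are incremented accordingly so the kernel-side statistics can be compared against this reference.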

View File

@ -88,7 +88,8 @@ struct CacheSeq<true, true>
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum);
template <typename MathElem, bool isPaged, bool useBeamSearch>
#if SPEC_DEC

View File

@ -150,7 +150,8 @@ template <uint32_t nbKHeads>
#endif
#endif
void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false,
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30,
float skipSoftmaxThresholdScaleFactor = 0.0f)
{
#if IS_MLA
if (nbKHeads != 1)
@ -224,6 +225,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
seqLen = (16U << 20) / gmemCacheHeadBytes; // 32MB per K+V head.
}
ctxLen = std::min(ctxLen, seqLen);
uint32_t skippedBlockCount = 0;
uint32_t totalBlockCount = 0;
if (skipSoftmaxThresholdScaleFactor > 0)
{
assert(useQGMMA);
}
float const kScale = cacheElemSize == 2 ? 1.f : 1 / 4.f;
float const vScale = kScale;
float const qScale = 1.f;
@ -329,6 +336,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
auto const rcpOutScale = ManagedMemBuf<float>(1);
auto const seqLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
auto const ctxLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
#if SKIP_SOFTMAX_ATTN
#ifdef SKIP_SOFTMAX_ATTN_BLOCK_STATS
auto const kernelSkippedBlockCount = ManagedMemBuf<uint32_t>(1);
auto const kernelTotalBlockCount = ManagedMemBuf<uint32_t>(1);
kernelSkippedBlockCount[0] = 0;
kernelTotalBlockCount[0] = 0;
#endif
#else
EXPECT_EQ(skipSoftmaxThresholdScaleFactor, 0.0f)
<< "Got non-zero skipSoftmaxThresholdScaleFactor while SKIP_SOFTMAX_ATTN is not enabled.";
#endif
#if USE_PAGED_KV_CACHE
auto const pageListBuf = ManagedMemBuf<std::byte>(pageListBytes);
#if PAGED_KV_CACHE_LAYOUT == 1
@ -726,6 +744,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
maxSeqLen, &seqLenList[0][0], batchSize, kvCacheScale.get(), semaphores.get(), scratch, stream);
};
#else
auto multiBlockNum = [&]()
{
auto const calcFunc = useQGMMA ? &computeNbSubSeqPerSeqHopperF8MHA : &computeNbSubSeqPerSeqMHA;
return calcFunc(prop, batchSize, nbKHeads, maxSeqLen);
}();
auto runKernel = [&]()
{
auto const launchFunc = useQGMMA ? &launchHopperF8MHA : &launchMHA;
@ -776,6 +799,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
batchSize, kvCacheScale.get(),
#if SPEC_DEC
specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
kernelSkippedBlockCount.get(), kernelTotalBlockCount.get(),
#endif
#endif
semaphores.get(), scratch, stream);
checkCuda(cudaGetLastError());
@ -813,6 +842,10 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
checkCuda(cudaEventRecord(toc, stream));
prefetchToDevice(cudaCpuDeviceId);
checkCuda(cudaStreamSynchronize(stream));
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
kernelSkippedBlockCount[0] /= nbIters;
kernelTotalBlockCount[0] /= nbIters;
#endif
if (testPerf)
{
float ms;
@ -849,6 +882,15 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
= totalNbCacheLoadBytes + inputBytes + outputBytes; // we ignore page indices and beam search indices.
float const dramSolTime = totalTraffic / bandwidth * 1E3f;
float const dramSolRatio = dramSolTime / ms;
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes
* (nbKHeads + nbVHeads * (1 - 1.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]))
* nbLoadedCacheTokens;
float const totalTrafficWithSkip
= totalNbCacheLoadWithSkip + inputBytes + outputBytes; // we ignore page indices and beam search indices.
float const dramSolTimeWithSkip = totalTrafficWithSkip / bandwidth * 1E3f;
float const dramSolRatioWithSkip = dramSolTimeWithSkip / ms;
#endif
if (verbose)
{
printf("done\n");
@ -863,7 +905,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
}
float const tops = headGrpSize * qSeqLen * float(seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
* nbKHeads * batchSize / (ms * 1E-3F) * 1E-12F;
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
printf("dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n", dramSolRatioWithSkip * 100, ms, tops);
#else
printf("dramSolRatio: %f%% (%f ms, TOPS = %f)\n", dramSolRatio * 100, ms, tops);
#endif
}
if (refCheck)
{
@ -1084,8 +1132,8 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
if (useQGMMA)
{
refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize,
refAttentionSinks);
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks,
skipSoftmaxThresholdScaleFactor, &skippedBlockCount, &totalBlockCount, multiBlockNum);
// refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
// vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
}
@ -1132,6 +1180,14 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
#endif
}
}
#if SKIP_SOFTMAX_ATTN
printf("host skippedBlockCount: %d/%d (%.2f%%)\n", skippedBlockCount, totalBlockCount,
totalBlockCount == 0 ? 0.0f : 100.0f * skippedBlockCount / totalBlockCount);
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
#endif
#endif
if (saveData)
{
fout_refOutput.close();
@ -1253,6 +1309,14 @@ TEST(RefCheck, llama_V2_70b)
#if SLIDING_WINDOW
runTest<2>(2, 4096, false, true, false, false, false, ~0, 256);
runTest<2>(2, 400, false, true, false, false, false, ~0U, 256);
#endif
#if SKIP_SOFTMAX_ATTN
runTest<1>(32, 2048, false, true, false, false, false, ~0U, 1U << 30, 0.f);
runTest<4>(32, 1538, false, true, false, false, false, ~0U, 1U << 30, 1280.f);
runTest<2>(32, 4096, false, true, false, false, false, ~0U, 1U << 30, 125.f);
runTest<4>(32, 300, false, true, false, false, false, ~0U, 1U << 30, 80.f);
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 501.0f);
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 500.f);
#endif
runTest<8>(120, 367, false, true);
runTest<8>(1792, 2048, false, true);
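The skip-adjusted DRAM-SOL model above discounts only the V-cache traffic by the skipped-block fraction, presumably because K still has to be loaded to evaluate the row max that drives the skip decision. A small self-contained illustration of that formula (the numbers are made up, not taken from the test):

#include <cstdio>

int main()
{
    double const gmemCacheHeadBytes = 576;      // bytes per K or V head per token (illustrative)
    double const nbKHeads = 8, nbVHeads = 8;
    double const nbLoadedCacheTokens = 1 << 20; // tokens actually read from the KV cache
    double const skippedFrac = 0.30;            // kernelSkippedBlockCount / kernelTotalBlockCount

    // K heads are counted in full; only V loads are discounted, matching the formula above.
    double const cacheBytes
        = gmemCacheHeadBytes * (nbKHeads + nbVHeads * (1.0 - skippedFrac)) * nbLoadedCacheTokens;
    std::printf("skip-adjusted cache traffic: %.2f GiB\n", cacheBytes / double(1 << 30));
    return 0;
}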

View File

@ -157,6 +157,11 @@ set(UCX_WRAPPER_TARGET tensorrt_llm_ucx_wrapper)
if(NIXL_ROOT)
set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
set(TRANSFER_AGENT_BINDING_TARGET tensorrt_llm_transfer_agent_binding)
endif()
if(MOONCAKE_ROOT)
set(MOONCAKE_WRAPPER_TARGET tensorrt_llm_mooncake_wrapper)
endif()
add_subdirectory(executor)
@ -272,6 +277,11 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
add_dependencies(${SHARED_TARGET} ${NIXL_WRAPPER_TARGET})
endif()
if(TARGET ${MOONCAKE_WRAPPER_TARGET})
target_link_libraries(${MOONCAKE_WRAPPER_TARGET} INTERFACE ${SHARED_TARGET})
add_dependencies(${SHARED_TARGET} ${MOONCAKE_WRAPPER_TARGET})
endif()
if(NOT WIN32)
# Load libraries at $PREFIX/lib from
# $PREFIX/lib/python3.12/site-packages/tensorrt_llm/libs

View File

@ -154,7 +154,8 @@ bool CacheFormatter::needSendCache(
return true;
}
int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
int selfTpRankInDpGroup = selfTpRank;
if (selfConfig.getParallelConfig().mEnableAttentionDP)
{

View File

@ -81,6 +81,11 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransc
backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
TLLM_LOG_INFO("Enable NIXL KV cache transport.");
}
else if (common::getEnvUseMooncakeKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MOONCAKE;
TLLM_LOG_INFO("Enable MOONCAKE KV cache transport.");
}
else if (common::getEnvUseMPIKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MPI;
@ -203,9 +208,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState);
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
TLLM_LOG_INFO("NIXL Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
{
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());

View File

@ -358,8 +358,9 @@ public:
TransceiverTag::Id id;
RequestInfo info;
auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info)
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
auto const* connection = isAgent
? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate)
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id));
if (connection == nullptr && !mManager->isRunning())
{
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
@ -395,8 +396,8 @@ public:
if (it == mRequestToSession.end())
{
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager,
info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
DataContext{tagFromRequestId(requestId), mTerminate}, mSelfState, info.getTransState(),
mBufferManager, info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
!common::getEnvKVCacheTimeOutputPath().empty());
session.setTime(TransferSession::kTimeRequestInfo);
it = mRequestToSession.emplace(requestId, std::move(session)).first;
@ -685,6 +686,10 @@ private:
{
future.get();
}
if (mResponseFuture.valid())
{
mResponseFuture.get();
}
}
void removeResponse(std::map<RequestIdType, Response>::iterator it)
@ -886,9 +891,9 @@ public:
}
}
auto const& resource = getReceiveCacheResource(llmRequest);
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(),
&llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId), mTerminate},
mSelfState, contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(),
requestInfo.getLastBlockKey(), &llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
}
std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
@ -964,7 +969,7 @@ public:
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
TLLM_CHECK(agentConnection);
isReady = agentConnection->recvReadySignal(
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG, mTerminate});
}
else
{
@ -979,6 +984,7 @@ public:
~Impl()
{
mTerminate.store(true);
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
{
asyncResource->mTerminate = true;
@ -1134,6 +1140,7 @@ private:
runtime::BufferManager mBufferManager;
std::ofstream mMeasuresFile;
std::mutex mMeasuresFileMutex;
std::atomic<bool> mTerminate{false};
};
void CacheSender::ImplDeleter::operator()(Impl* ptr)
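The new mTerminate flag threaded into DataContext and recvConnectionAndRequestInfo lets blocking receives bail out when the sender is being destroyed, so ~Impl() no longer risks hanging on a receive that will never complete. A minimal, self-contained sketch of that pattern (recvOrGiveUp is a toy stand-in; how DataContext stores the flag is an assumption here):

#include <atomic>
#include <chrono>
#include <cstdio>
#include <optional>
#include <thread>

// Polls for work until the owner sets `terminate`; the caller treats an empty
// result like "no connection" and exits its loop.
std::optional<int> recvOrGiveUp(std::atomic<bool> const& terminate)
{
    while (!terminate.load(std::memory_order_acquire))
    {
        // A real implementation would poll the transport here.
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    return std::nullopt;
}

int main()
{
    std::atomic<bool> terminate{false};
    std::thread receiver(
        [&] { std::printf("received a message: %s\n", recvOrGiveUp(terminate).has_value() ? "yes" : "no"); });
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
    terminate.store(true, std::memory_order_release); // the first thing the destructor now does
    receiver.join();
    return 0;
}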

View File

@ -1224,7 +1224,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr && blockItr != blockKeys.end()
? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse)
: std::make_tuple(false, 0, nullptr);
if (matchingBlock != nullptr)
if (matchingBlock != nullptr && numMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen())
{
KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId();
@ -1338,6 +1338,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
}
}
sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens);
return numMatchedTokens;
}
@ -1555,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
}
}
std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
{
SizeType32 numBlocksStoredForReuse = 0;
@ -1568,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
auto numBlocks = blockKeys.size();
std::vector<BlockPtr> storedBlocks;
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
{
auto const bid = blockIds[blockCnt];
@ -1619,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
if (pinBlocks)
{
searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
}
lastStoredId = searchRoot->getBlockId();
}
if (mEventManager)
{
mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
}
return {numBlocksStoredForReuse, lastStoredId};
return {numBlocksStoredForReuse, pinnedBlockIds};
}
void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
@ -1714,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
}
std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (auto& [_, manager] : mWindowBlockManagers)
{
lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
}
return lastStoredId;
return pinnedBlockIds;
}
std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
@ -1731,9 +1732,22 @@ std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
// Released block will be stored when reuse is enabled.
// Reuse is implied to be enabled if llmRequest is provided.
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
// For now, the attention kernel only accepts a single
// "prepopulatedPromptLen", i.e. all window sizes use the same
// prepopulated prompt length, so there is no point in storing blocks
// for reuse for one window size while the blocks of another window
// size are not valid for storing.
bool isAllWindowSizesValidForStoreForReuse = true;
for (auto& [windowSize, manager] : mWindowBlockManagers)
{
isAllWindowSizesValidForStoreForReuse &= manager.isSequenceValidForStoreForReuse(sequence.getRequestId());
}
for (auto& [_, manager] : mWindowBlockManagers)
{
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1)
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1
|| !isAllWindowSizesValidForStoreForReuse)
{
lastStoredId = manager.releaseBlocks(sequence, std::nullopt);
}
@ -1753,7 +1767,7 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
}
}
void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
// Use the first window size
if (mWindowBlockManagers.empty())
@ -1761,7 +1775,7 @@ void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
return;
}
auto& firstManager = mWindowBlockManagers.begin()->second;
firstManager.unpinBlocksById(blockId);
firstManager.unpinBlocksById(blockIds);
}
void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
@ -1774,21 +1788,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
}
}
void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
if (blockIds.empty())
{
return;
}
auto block = mAllBlocksById[blockId];
while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
for (auto const& blockId : blockIds)
{
block->decRefCount();
if (!block->hasRefs())
TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
"Block id %d is out of range", blockId);
auto block = mAllBlocksById[blockId];
if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
{
mEvictionPolicy->releaseBlock(block);
block->decRefCount();
if (!block->hasRefs())
{
mEvictionPolicy->releaseBlock(block);
}
}
block = std::move(block->getPrevBlock());
}
}
@ -1856,7 +1875,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
(void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
}
std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
auto constexpr beamIdx = 0;
@ -1869,7 +1888,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
return pinnedBlockIds;
}
std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
@ -1908,7 +1930,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
[](BlockPtr const& block) { return block->getBlockId(); });
auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
sequence.getRequestId(), numBlocksStoredForReuse);
}
@ -2485,15 +2507,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
return lastStoredId;
}
std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
auto& sequence = getSequence(requestId);
std::optional<KVCacheBlock::IdType> lastStoredId
= mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
return lastStoredId;
return pinnedBlockIds;
}
void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
@ -2508,9 +2529,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
mBlockManager.pinBlocks(sequence);
}
void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
mBlockManager.unpinBlocksById(blockId);
mBlockManager.unpinBlocksById(blockIds);
}
SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
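A self-contained toy (not the real KVCacheBlock API) illustrating the new pin/unpin pairing: storeBlocksForReuse(..., pinBlocks=true) now returns every pinned block id, and unpinBlocksById() drops exactly one reference per id, releasing a block only when its count reaches zero, instead of walking the previous-block chain from a single last-stored id.

#include <cstdio>
#include <vector>

struct ToyBlock
{
    int refs = 0;
    bool released = false;
};

void unpinBlocksById(std::vector<ToyBlock>& all, std::vector<int> const& ids)
{
    for (int id : ids)
    {
        ToyBlock& b = all[id];
        if (--b.refs == 0)
        {
            b.released = true; // the real code hands the block back to the eviction policy here
        }
    }
}

int main()
{
    std::vector<ToyBlock> blocks(4);
    std::vector<int> const pinned{1, 2, 3}; // ids as returned by storeBlocksForReuse(..., pinBlocks=true)
    for (int id : pinned)
    {
        blocks[id].refs++; // pinning took one reference per stored block
    }
    unpinBlocksById(blocks, pinned);
    std::printf("block 2 released: %d\n", blocks[2].released);
    return 0;
}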

View File

@ -60,7 +60,8 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
bool MLACacheFormatter::needSendCache(
CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
{
int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize

View File

@ -296,7 +296,13 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
// Parameters for sparse attention
xqaParams.sparse_params = mRuntimeSparseAttentionParams;
xqaParams.use_sparse_attention = useTllmGenSparseAttention();
// Skip softmax threshold.
xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode;
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax, pointers of device memory for output
xqaParams.skip_softmax_total_blocks = mSkipSoftmaxTotalBlocks;
xqaParams.skip_softmax_skipped_blocks = mSkipSoftmaxSkippedBlocks;
#endif
// Cross attention parameters.
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
@ -1313,6 +1319,8 @@ int AttentionOp::mlaGeneration(
fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
}
// MLA does not support skip-softmax attention right now
// Run the fmha kernel
mDecoderFMHARunner->run(fmhaParams);
}
@ -1885,6 +1893,18 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
}
// Skip-softmax attention parameters
fmhaParams.skipSoftmaxThresholdScaleFactor = mSkipSoftmaxThresholdScaleFactorPrefill;
#ifdef SKIP_SOFTMAX_STAT
fmhaParams.skipSoftmaxTotalBlocks = mSkipSoftmaxTotalBlocks;
fmhaParams.skipSoftmaxSkippedBlocks = mSkipSoftmaxSkippedBlocks;
#else
if (tensorrt_llm::common::getEnvPrintSkipSoftmaxStat())
{
TLLM_THROW("To print skip softmax stat, please run build_wheel.py with -DSKIP_SOFTMAX_STAT");
}
#endif
if (mAttentionChunkSize)
{
fmhaParams.chunkedAttentionSize = *mAttentionChunkSize;

View File

@ -494,6 +494,14 @@ public:
// See [Chunked Attention] in _torch/modules/attention.py
std::optional<int64_t> mAttentionChunkSize = std::nullopt;
// Skip softmax threshold scale factor.
float mSkipSoftmaxThresholdScaleFactorPrefill = 0;
float mSkipSoftmaxThresholdScaleFactorDecode = 0;
#ifdef SKIP_SOFTMAX_STAT
uint32_t* mSkipSoftmaxTotalBlocks;
uint32_t* mSkipSoftmaxSkippedBlocks;
#endif
[[nodiscard]] auto data() const
{
return std::make_tuple(mLayerIdx, mNumHeads, mVisionStart, mVisionLength, mNumKVHeads, mHeadSize,
@ -510,7 +518,8 @@ public:
mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1), mSkipSoftmaxThresholdScaleFactorPrefill,
mSkipSoftmaxThresholdScaleFactorDecode);
};
private:

View File

@ -43,7 +43,7 @@ template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
@ -63,7 +63,7 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
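cudaGridDependencySynchronize() and cudaTriggerProgrammaticLaunchCompletion() are the CUDA device intrinsics for programmatic dependent launch (PDL) and correspond to the griddepcontrol.wait / griddepcontrol.launch_dependents PTX they replace here. A minimal kernel sketch (pdlKernel and the plain launch below are illustrative only; a real PDL launch goes through cudaLaunchKernelEx with the programmatic-stream-serialization attribute):

#include <cuda_runtime.h>

__global__ void pdlKernel(float* data, int n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize(); // wait until the producer grid's writes are visible
#endif
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        data[i] *= 2.0f;
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion(); // let dependent grids start launching
#endif
}

int main()
{
    float* d = nullptr;
    cudaMalloc(&d, 256 * sizeof(float));
    pdlKernel<<<1, 256>>>(d, 256);
    cudaDeviceSynchronize();
    cudaFree(d);
    return 0;
}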

View File

@ -40,50 +40,12 @@ inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
{
return common::getEnvAllReduceWorkspaceSize();
}
if (worldSize <= 2)
char const* envWorkspaceSize = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE");
if (envWorkspaceSize != nullptr)
{
return 16 * 1000 * 1000;
}
return 8 * 1000 * 1000;
}
// (SM major_version, TP_size) -> (NCCL_num_token_threshold, TWO_SHOT_numel_threshold)
inline std::unordered_map<int, std::unordered_map<int, std::pair<size_t, size_t>>> HeuristicThresholdLP{
{90,
{
{2, {4096, 4096 * 4096}},
{4, {4096, 1024 * 1024}},
{8, {2048, 512 * 512}},
}},
{100,
{
{2, {4096, 4096 * 4096}},
{4, {4096, 1024 * 2048}},
{8, {4096, 1024 * 1024}},
}},
};
inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size, int world_size, AllReduceFusionOp op)
{
// The heuristic is based on the following assumptions:
// __________________________________
// | \ TWO-SHOT zone |
// | ONE-SHOT zone \ | NCCL zone
// |_______________________\______|___
// sm_major is 90 or 100
auto const sm_major = std::min(100, std::max(90, tensorrt_llm::common::getSMVersion()));
auto const [nccl_num_token_threshold, two_shot_numel_threshold] = HeuristicThresholdLP[sm_major][world_size];
auto const message_size = seq_len * hidden_size;
if (message_size >= two_shot_numel_threshold)
{
return AllReduceStrategyType::TWOSHOT;
}
else
{
return AllReduceStrategyType::ONESHOT;
return static_cast<size_t>(std::atoi(envWorkspaceSize));
}
return 67108864; // 64 MiB
}
// use 1D vector to store the best strategy instead of a map for each sm version

View File

@ -249,7 +249,7 @@ bool getEnvUseTileSizeKv64ForTrtllmGen()
bool getEnvEnablePDL()
{
static std::once_flag flag;
static bool enablePDL = false;
static bool enablePDL = true;
std::call_once(flag,
[&]()
@ -257,7 +257,18 @@ bool getEnvEnablePDL()
if (getSMVersion() >= 90)
{
// PDL is now enabled by default on SM90+; set the env variable `TRTLLM_ENABLE_PDL` to `0` to disable it or `1` to force-enable it
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
char const* env = std::getenv("TRTLLM_ENABLE_PDL");
if (env)
{
if (env[0] == '1' && env[1] == '\0')
{
enablePDL = true;
}
else if (env[0] == '0' && env[1] == '\0')
{
enablePDL = false;
}
};
}
});
return enablePDL;
@ -281,6 +292,12 @@ bool getEnvUseNixlKvCache()
return useNixlKvCache;
}
bool getEnvUseMooncakeKvCache()
{
static bool const useMooncakeKvCache = getBoolEnv("TRTLLM_USE_MOONCAKE_KVCACHE");
return useMooncakeKvCache;
}
bool getEnvUseRoundRobinBlockDistForCP()
{
static bool const useRoundRobinBlockDistForCP = getBoolEnv("TRTLLM_USE_ROUND_ROBIN_BLOCK_DIST_FOR_CP");
@ -343,6 +360,23 @@ std::string getEnvNixlBackend()
return nixlBackend;
}
std::string getEnvMooncakeInterface()
{
static std::once_flag flag;
static std::string mooncakeInterface;
std::call_once(flag,
[&]()
{
char const* mooncake_interface = std::getenv("TRTLLM_MOONCAKE_INTERFACE");
if (mooncake_interface)
{
mooncakeInterface = mooncake_interface;
}
});
return mooncakeInterface;
}
bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
@ -531,6 +565,11 @@ bool getEnvEplbForceGdrcopy()
return getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY");
}
bool getEnvPrintSkipSoftmaxStat()
{
return getBoolEnv("TRTLLM_PRINT_SKIP_SOFTMAX_STAT");
}
} // namespace common
TRTLLM_NAMESPACE_END
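A minimal sketch of the new default-on PDL switch above, assuming only the exact strings "1" and "0" are honored and anything else keeps the default:

#include <cstdio>
#include <cstdlib>

bool resolveEnablePdl(int smVersion)
{
    bool enable = true; // new default (previously false)
    if (smVersion >= 90)
    {
        // As in the hunk, the env override is only consulted on SM90+.
        if (char const* env = std::getenv("TRTLLM_ENABLE_PDL"))
        {
            if (env[0] == '1' && env[1] == '\0')
            {
                enable = true;
            }
            else if (env[0] == '0' && env[1] == '\0')
            {
                enable = false;
            }
        }
    }
    return enable;
}

int main()
{
    std::printf("PDL enabled on SM90: %d\n", resolveEnablePdl(90));
    return 0;
}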

View File

@ -83,8 +83,11 @@ inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 g
bool getEnvUseUCXKvCache();
bool getEnvUseMPIKvCache();
bool getEnvUseNixlKvCache();
bool getEnvUseMooncakeKvCache();
bool getEnvUseRoundRobinBlockDistForCP();
std::string getEnvUCXInterface();
@ -93,6 +96,8 @@ std::string getEnvNixlInterface();
std::string getEnvNixlBackend();
std::string getEnvMooncakeInterface();
bool getEnvDisaggLayerwise();
bool getEnvParallelCacheSend();
@ -156,6 +161,8 @@ bool getEnvKVCacheTransferAllBlocksForWindow();
bool getEnvEplbForceGdrcopy();
bool getEnvPrintSkipSoftmaxStat();
} // namespace common
TRTLLM_NAMESPACE_END

View File

@ -0,0 +1,226 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ipUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <arpa/inet.h>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
#include <unistd.h>
TRTLLM_NAMESPACE_BEGIN
namespace common
{
std::string getLocalIpByNic(std::string const& interface, int rank)
{
struct ifaddrs* ifaddr = nullptr;
if (getifaddrs(&ifaddr) == -1)
{
TLLM_LOG_ERROR(rank,
"getLocalIpByNic: Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is "
"set "
"correctly.");
return std::string{};
}
for (struct ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next)
{
if (ifa->ifa_addr == nullptr)
{
continue;
}
if (ifa->ifa_name == interface)
{
if (ifa->ifa_addr->sa_family == AF_INET)
{
char ip[INET_ADDRSTRLEN]{};
void* addr = &((reinterpret_cast<struct sockaddr_in*>(ifa->ifa_addr))->sin_addr);
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "0.0.0.0") != 0)
{
freeifaddrs(ifaddr);
return std::string(ip);
}
}
else if (ifa->ifa_addr->sa_family == AF_INET6)
{
char ip[INET6_ADDRSTRLEN]{};
void* addr = &((reinterpret_cast<struct sockaddr_in6*>(ifa->ifa_addr))->sin6_addr);
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
&& std::strcmp(ip, "::1") != 0)
{
freeifaddrs(ifaddr);
return std::string(ip);
}
}
}
}
freeifaddrs(ifaddr);
TLLM_LOG_ERROR(
rank, "Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is set correctly.");
return std::string{};
}
std::string getLocalIpByHostname(int rank)
{
char hostname[256]{};
if (gethostname(hostname, sizeof(hostname)) == -1)
{
TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
return std::string{};
}
struct addrinfo hints = {};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_CANONNAME;
struct addrinfo* res = nullptr;
if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
{
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
return std::string{};
}
for (struct addrinfo* p = res; p != nullptr; p = p->ai_next)
{
if (p->ai_family == AF_INET)
{ // IPv4
char ip[INET_ADDRSTRLEN]{};
struct sockaddr_in* ipv4 = reinterpret_cast<struct sockaddr_in*>(p->ai_addr);
void* addr = &(ipv4->sin_addr);
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "127.0.0.1") != 0
&& std::strcmp(ip, "0.0.0.0") != 0)
{
freeaddrinfo(res);
return std::string(ip);
}
}
else if (p->ai_family == AF_INET6)
{ // IPv6
char ip[INET6_ADDRSTRLEN]{};
struct sockaddr_in6* ipv6 = reinterpret_cast<struct sockaddr_in6*>(p->ai_addr);
void* addr = &(ipv6->sin6_addr);
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
&& std::strcmp(ip, "::1") != 0)
{
freeaddrinfo(res);
return std::string(ip);
}
}
}
freeaddrinfo(res);
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
return std::string{};
}
std::string getLocalIpByRemoteOrHostName(int rank)
{
// Try IPv4
struct sockaddr_in addr
{
};
addr.sin_family = AF_INET;
addr.sin_port = htons(80);
// using google's public dns server to get the local ip which can be accessed from remote
char const* dns_ip_v4 = "8.8.8.8";
inet_pton(AF_INET, dns_ip_v4, &addr.sin_addr);
int sock = socket(AF_INET, SOCK_DGRAM, 0);
if (sock != -1)
{
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != -1)
{
socklen_t addr_len = sizeof(addr);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) != -1)
{
char ip[INET_ADDRSTRLEN]{};
inet_ntop(AF_INET, &addr.sin_addr, ip, sizeof(ip));
close(sock);
return std::string(ip);
}
}
close(sock);
}
// Try IPv6
struct sockaddr_in6 addr6
{
};
addr6.sin6_family = AF_INET6;
addr6.sin6_port = htons(80);
// using google's public dns server
char const* dns_ipv6 = "2001:4860:4860::8888";
inet_pton(AF_INET6, dns_ipv6, &addr6.sin6_addr);
sock = socket(AF_INET6, SOCK_DGRAM, 0);
if (sock != -1)
{
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr6), sizeof(addr6)) != -1)
{
socklen_t addr_len = sizeof(addr6);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr6), &addr_len) != -1)
{
char ip[INET6_ADDRSTRLEN]{};
inet_ntop(AF_INET6, &addr6.sin6_addr, ip, sizeof(ip));
close(sock);
return std::string(ip);
}
}
close(sock);
}
// Try hostname
return getLocalIpByHostname(rank);
}
std::string getLocalIp(std::string interface, int rank)
{
std::string localIP = {};
if (!interface.empty())
{
localIP = getLocalIpByNic(interface, rank);
}
if (localIP.empty())
{
localIP = getLocalIpByRemoteOrHostName(rank);
}
// check whether the localIP is valid
if (localIP.empty())
{
TLLM_THROW("getLocalIp: Can't get local ip");
}
return localIP;
}
} // namespace common
TRTLLM_NAMESPACE_END
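getLocalIpByRemoteOrHostName() relies on the classic UDP-connect trick: connecting a datagram socket to a public address sends no packet, but getsockname() then reports which local interface address the kernel would route through. A self-contained demo of just that trick (IPv4 only):

#include <arpa/inet.h>
#include <cstdio>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

int main()
{
    sockaddr_in remote{};
    remote.sin_family = AF_INET;
    remote.sin_port = htons(80);
    inet_pton(AF_INET, "8.8.8.8", &remote.sin_addr); // any routable address works; nothing is sent

    int const sock = socket(AF_INET, SOCK_DGRAM, 0);
    if (sock == -1)
    {
        return 1;
    }
    if (connect(sock, reinterpret_cast<sockaddr*>(&remote), sizeof(remote)) == 0)
    {
        sockaddr_in local{};
        socklen_t len = sizeof(local);
        if (getsockname(sock, reinterpret_cast<sockaddr*>(&local), &len) == 0)
        {
            char ip[INET_ADDRSTRLEN]{};
            inet_ntop(AF_INET, &local.sin_addr, ip, sizeof(ip));
            std::printf("routable local IPv4: %s\n", ip);
        }
    }
    close(sock);
    return 0;
}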

View File

@ -0,0 +1,28 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/config.h"
#include <string>
TRTLLM_NAMESPACE_BEGIN
namespace common
{
std::string getLocalIp(std::string interface, int rank);
} // namespace common
TRTLLM_NAMESPACE_END

View File

@ -37,6 +37,46 @@ NcclCommResourceManager& NcclCommResourceManager::getInstance() noexcept
return instance;
}
NcclCommResourceManager::~NcclCommResourceManager()
{
// Mark that we're in destruction to prevent cleanup attempts from deleters
// that may run during static destruction
mIsDestroying.store(true, std::memory_order_release);
// Proactively clean up all resources before destruction
// This ensures cleanup happens in a controlled manner before static destruction
std::vector<std::pair<ncclComm_t, std::vector<ResourceEntry>>> allResources;
{
std::lock_guard<std::mutex> lock(mMutex);
// Move all resources out of the map
allResources.reserve(mCommResources.size());
for (auto& [comm, resources] : mCommResources)
{
allResources.emplace_back(comm, std::move(resources));
}
mCommResources.clear();
}
// Clean up all resources outside the lock
// Note: We don't call ncclCommDestroy here - that's the responsibility
// of the shared_ptr deleter. We just clean up registered resources.
for (auto& [comm, resources] : allResources)
{
for (auto& [cleanup, name] : resources)
{
try
{
cleanup();
}
catch (...)
{
// Ignore exceptions during destruction
}
}
}
}
void NcclCommResourceManager::registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName)
{
if (!comm)
@ -60,23 +100,56 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
return;
}
// Check if we're in the process of being destroyed
// If so, skip cleanup - the destructor will handle it proactively
if (mIsDestroying.load(std::memory_order_acquire))
{
return;
}
std::vector<ResourceEntry> resourcesToClean;
{
std::lock_guard<std::mutex> lock(mMutex);
auto it = mCommResources.find(comm);
if (it == mCommResources.end())
// During static destruction, mutex and logging may not be safe.
// Use try-catch to handle any issues gracefully.
try
{
// Nothing registered for this comm, nothing to clean up
std::lock_guard<std::mutex> lock(mMutex);
// Double-check after acquiring lock (destruction may have started)
if (mIsDestroying.load(std::memory_order_acquire))
{
return;
}
auto it = mCommResources.find(comm);
if (it == mCommResources.end())
{
// Nothing registered for this comm, nothing to clean up
return;
}
// Move resources out (preserves order) and remove from map
resourcesToClean = std::move(it->second);
mCommResources.erase(it);
// Logging may fail during static destruction, so wrap in try-catch
try
{
TLLM_LOG_TRACE("[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(),
static_cast<void*>(comm));
}
catch (...)
{
// Ignore logging failures during static destruction
}
}
catch (...)
{
// If mutex access fails during static destruction, just return.
// This prevents segfaults when the singleton is being destroyed.
return;
}
// Move resources out (preserves order) and remove from map
resourcesToClean = std::move(it->second);
mCommResources.erase(it);
TLLM_LOG_TRACE(
"[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(), static_cast<void*>(comm));
}
// Clean up outside the lock to avoid deadlocks if cleanup functions try to access the manager
@ -85,19 +158,41 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
{
try
{
TLLM_LOG_TRACE(
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
// Logging may fail during static destruction, so wrap in try-catch
try
{
TLLM_LOG_TRACE(
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
}
catch (...)
{
// Ignore logging failures during static destruction
}
cleanup();
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s", name.c_str(),
static_cast<void*>(comm), e.what());
try
{
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s",
name.c_str(), static_cast<void*>(comm), e.what());
}
catch (...)
{
// Ignore logging failures during static destruction
}
}
catch (...)
{
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
name.c_str(), static_cast<void*>(comm));
try
{
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
name.c_str(), static_cast<void*>(comm));
}
catch (...)
{
// Ignore logging failures during static destruction
}
}
}
}
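A compact, self-contained sketch of the destruction-guard pattern added above (ResourceRegistry is a toy, not the real NcclCommResourceManager): the singleton flips an atomic flag at the start of its destructor, so cleanup calls that arrive later from other static destructors become no-ops instead of touching a half-destroyed map, while the destructor itself runs all remaining cleanups and swallows any exceptions.

#include <atomic>
#include <cstdio>
#include <functional>
#include <mutex>
#include <unordered_map>
#include <vector>

class ResourceRegistry
{
public:
    static ResourceRegistry& instance()
    {
        static ResourceRegistry r;
        return r;
    }

    void add(int key, std::function<void()> cleanupFn)
    {
        std::lock_guard<std::mutex> lock(mMutex);
        mResources[key].push_back(std::move(cleanupFn));
    }

    void cleanup(int key)
    {
        if (mDestroying.load(std::memory_order_acquire))
        {
            return; // a late call during static destruction; the destructor cleans up instead
        }
        std::vector<std::function<void()>> fns;
        {
            std::lock_guard<std::mutex> lock(mMutex);
            auto it = mResources.find(key);
            if (it == mResources.end())
            {
                return;
            }
            fns = std::move(it->second);
            mResources.erase(it);
        }
        for (auto& fn : fns)
        {
            fn(); // run outside the lock, as in the real manager
        }
    }

private:
    ~ResourceRegistry()
    {
        mDestroying.store(true, std::memory_order_release); // make later cleanup() calls no-ops
        for (auto& [key, fns] : mResources)
        {
            for (auto& fn : fns)
            {
                try
                {
                    fn();
                }
                catch (...)
                {
                    // ignore failures during static destruction
                }
            }
        }
    }

    std::atomic<bool> mDestroying{false};
    std::mutex mMutex;
    std::unordered_map<int, std::vector<std::function<void()>>> mResources;
};

int main()
{
    ResourceRegistry::instance().add(0, [] { std::printf("cleaning resource 0\n"); });
    ResourceRegistry::instance().cleanup(0); // explicit cleanup while the singleton is alive
    return 0;
}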

View File

@ -26,6 +26,7 @@
#endif
#include <algorithm>
#include <atomic>
#include <functional>
#include <limits>
#include <memory>
@ -139,12 +140,13 @@ public:
private:
NcclCommResourceManager() = default;
~NcclCommResourceManager() = default;
~NcclCommResourceManager();
using ResourceEntry = std::pair<ResourceCleanupFunc, std::string>;
mutable std::mutex mMutex;
std::unordered_map<ncclComm_t, std::vector<ResourceEntry>> mCommResources;
std::atomic<bool> mIsDestroying{false};
};
// RAII helper to register a resource with a NCCL communicator.

View File

@ -123,13 +123,24 @@ std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
if (*comm)
{
// Clean up all registered resources FIRST
// The cleanupResources function uses a destruction guard to safely handle
// static destruction order issues - it will return early if the singleton
// is being destroyed (in which case the destructor handles cleanup proactively)
tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*comm);
// Now destroy the NCCL communicator
ncclResult_t result = ncclCommDestroy(*comm);
if (result != ncclSuccess)
{
TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
// Logging may fail during static destruction, so wrap in try-catch
try
{
TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
}
catch (...)
{
// Ignore logging failures during static destruction
}
}
// Clear the communicator value before freeing the pointer

View File

@ -46,7 +46,7 @@ CUTLASS_DEVICE
void launch_dependent_grids()
{
#if (defined(CUTLASS_GDC_ENABLED))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
@ -57,7 +57,7 @@ CUTLASS_DEVICE
void wait_on_dependent_grids()
{
#if (defined(CUTLASS_GDC_ENABLED))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
}

View File

@ -686,4 +686,212 @@ public:
}
};
template <class Collective>
struct MixedInputUtilsSM100
{
private:
using KernelSchedule = typename Collective::KernelSchedule;
using ConversionMode = typename Collective::ConversionMode;
using SmemLayoutA = typename Collective::SmemLayoutA;
using SmemLayoutB = typename Collective::SmemLayoutB;
using ElementScale = typename Collective::ElementScale;
using ElementZero = typename Collective::ElementZero;
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
public:
// Helper functions to select packing for conversion
template <class SrcType, class DstType, int Cosize>
struct select_packing
{ // Naive packing policy
static constexpr auto value()
{
return Int<cute::gcd(Cosize, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
}
};
/// (Designed for separate transform pipeline in Blackwell)
/// Utilities to dequantize A.
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
CUTLASS_DEVICE static void dequantize_A_kblock_for_transform(Tensor<EngineIn, LayoutIn> const& tArA,
Tensor<EngineOut, LayoutOut>& tArACompute, cute::tuple<Ts...> const& partitioned_extra_info, int const k_block)
{
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
using SrcType = typename EngineIn::value_type;
using DstType = typename EngineOut::value_type;
auto src = tArA(_, _, _, k_block);
auto dst = tArACompute(_, _, _, k_block);
auto pSrc = raw_pointer_cast(src.data());
auto pDst = const_cast<DstType*>(raw_pointer_cast(dst.data()));
constexpr int num_elements = decltype(size(src))::value;
constexpr int pack = decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
using Converter
= cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
constexpr int DstElementsPerReg = 32 / sizeof_bits_v<DstType>;
using RegArray = cutlass::AlignedArray<uint32_t, pack / DstElementsPerReg, sizeof(DstArray)>;
auto src_arr = recast<SrcArray>(src);
auto dst_arr = recast<DstArray>(dst);
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, pack));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert)
{
cute::transform(src_arr, dst_arr, Converter::convert);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
{
auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
if constexpr (is_same_v<DstType, ElementScale>)
{
cute::transform(src_arr, dst_arr, Converter::convert);
using ScaleArray = cutlass::Array<ElementScale, pack>;
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
{
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
for (int i = 0; i < size<1>(dst_vm); ++i)
{
auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
CUTLASS_PRAGMA_UNROLL
for (size_t ii = 0; ii < RegArray::kElements; ++ii)
{
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
}
}
}
else
{
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
}
}
else
{
constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
constexpr int pack = cute::gcd(pack1, pack2);
using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
using StageArray = cutlass::Array<ElementScale, pack>;
constexpr int iters = num_elements / pack;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < iters; ++i)
{
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
StageArray stageArr;
stageArr = Converter1::convert(*pSrcArr);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < pack; ++j)
{
stageArr[j] = stageArr[j] * scales[i * pack + j];
}
*pDstArr = Converter2::convert(stageArr);
}
}
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
{
static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
if constexpr (is_same_v<DstType, ElementZero>)
{
cute::transform(src_arr, dst_arr, Converter::convert);
using ScaleArray = cutlass::Array<ElementScale, pack>;
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
using ZeroArray = cutlass::Array<ElementZero, pack>;
auto zero_arr = recast<ZeroArray>(filter_zeros(zeros));
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
{
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, pack));
for (int i = 0; i < size<1>(dst_vm); ++i)
{
auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
auto&& zero_reg = cute::recast<RegArray>(zeros_vm(_, i))(0);
CUTLASS_PRAGMA_UNROLL
for (size_t ii = 0; ii < RegArray::kElements; ++ii)
{
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
bf16x2_val = __hadd2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(zero_reg[ii]));
}
}
}
else
{
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{});
}
}
else
{
constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
constexpr int pack = cute::gcd(pack1, pack2);
using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
using StageArray = cutlass::Array<ElementScale, pack>;
constexpr int iters = num_elements / pack;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < iters; ++i)
{
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
StageArray stageArr;
stageArr = Converter1::convert(*pSrcArr);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < pack; ++j)
{
stageArr[j] = stageArr[j] * scales[i * pack + j] + zeros[i * pack + j];
}
*pDstArr = Converter2::convert(stageArr);
}
}
}
else
{
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled for input partitioning.");
}
}
};
} // namespace cutlass::gemm::collective::detail
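A scalar sketch (plain float, no CUTLASS types) of the three conversion modes dispatched above: DirectConvert, ConvertAndScale, and ConvertAndScaleWithZero. The real path converts packed register arrays and uses bf16x2 __hmul2/__hadd2 when the destination type is bfloat16, but the per-element arithmetic is simply convert, multiply by the scale, then add the zero point.

#include <cstdint>
#include <cstdio>

float dequant(int8_t q, float scale, float zero, int mode)
{
    float x = static_cast<float>(q); // DirectConvert
    if (mode >= 1)
    {
        x *= scale; // ConvertAndScale
    }
    if (mode >= 2)
    {
        x += zero; // ConvertAndScaleWithZero
    }
    return x;
}

int main()
{
    std::printf("%f\n", dequant(-5, 0.25f, 1.0f, 2)); // -5 * 0.25 + 1 = -0.25
    return 0;
}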

View File

@ -0,0 +1,294 @@
/*
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/collective/builders/sm100_common.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
int stages>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(StageCount<stages> stage_count)
{
constexpr int Load2TransformStageCount = stages;
constexpr int Transform2MmaStageCount = stages;
constexpr int AccumulatorStageCount = stages;
return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
int carveout_bytes>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(
StageCountAutoCarveout<carveout_bytes> stage_count)
{
constexpr int CtaM = get<0>(CtaTileShape_MNK{});
constexpr int CtaN = get<1>(CtaTileShape_MNK{});
static_assert(CtaN <= 128, "Can't support CtaN>128 tiles");
constexpr int CtaK = get<2>(CtaTileShape_MNK{});
using AtomThrID = typename TiledMma::AtomThrID;
constexpr int TmemColumns = 512;
constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K
&& !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>;
constexpr bool IsAComputeinSmem = !IsAComputeinTmem;
// Detect 2x2 TMEM layout
constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN / 2 : CtaN;
constexpr int TmemAWordsPerDP = CtaK / 2;
constexpr int AccumulatorStageCount
= (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP);
constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32);
constexpr int TmemInAStageCount_Potential
= (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000;
// Mainload2Transform Pipeline
constexpr auto load2transform_pipeline_bytes
= sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage);
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>; // ElementA introduced here
constexpr auto s_bits = cute::is_void_v<ElementScale> ? 0 : cute::sizeof_bits_v<ElementScale>;
constexpr auto z_bits = cute::is_void_v<ElementZero> ? 0 : cute::sizeof_bits_v<ElementZero>;
constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage);
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>; // ElementB introduced here
constexpr int ab_stage_bytes
= cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
+ cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
+ cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
+ cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{}))
+ static_cast<int>(load2transform_pipeline_bytes) + static_cast<int>(load2mma_pipeline_bytes);
// Transform2Mma Pipeline
constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage);
constexpr auto a_compute_bits = cute::sizeof_bits_v<ElementAMma>;
constexpr int ab_compute_stage_bytes = cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem)
* size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
+ // If ACompute is in TMEM, Acompute buffer has 0 bytes.
static_cast<int>(transform2mma_pipeline_bytes);
constexpr int ABComputeStageCount_Potential
= SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes);
// The number of SMEM buffers for A, B, ACompute (if in SMEM), and BCompute should be at least Transform2MmaStageCount
constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential);
constexpr int SmemCapacityAfterABComputeCarveout
= SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes);
// Can we boost the number of buffers for A and B?
constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes;
static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2,
"Not enough SMEM or TMEM capacity for selected tile size");
return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}
} // namespace detail
// Mixed Input MMA kernels builder
template <class ElementAOptionalTuple, class GmemLayoutATagTuple, int AlignmentA, class ElementBOptionalTuple,
class GmemLayoutBTag, int AlignmentB, class ElementAccumulator,
class TileShape_MNK, // The Cluster-level TileShape
class ClusterShape_MNK, class StageCountType, class KernelScheduleType>
struct CollectiveBuilderSm100WeightOnly<arch::Sm100, arch::OpClassTensorOp,
ElementAOptionalTuple, // ElementA
GmemLayoutATagTuple, // LayoutA
AlignmentA,
ElementBOptionalTuple, // ElementB
GmemLayoutBTag, // LayoutB
AlignmentB, ElementAccumulator,
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int)
StageCountType, KernelScheduleType,
cute::enable_if_t<(cute::is_base_of_v<KernelScheduleSm100MixedInputGemm, KernelScheduleType>) &&(
(sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0)
&& ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>>
{
using GmemLayoutATag = detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>;
using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>;
static constexpr cute::UMMA::Major UmmaMajorA
= cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
static constexpr cute::UMMA::Major UmmaMajorB
= cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>;
using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>;
using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>;
static constexpr bool NeitherIsTuple
= !cute::is_tuple<ElementAOptionalTuple>::value && !cute::is_tuple<ElementBOptionalTuple>::value;
static constexpr bool IsANarrow = cute::sizeof_bits_v<ElementA> < cute::sizeof_bits_v<ElementB>;
static constexpr bool IsMixedInput = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
static_assert(IsMixedInput, "Mixed Input GEMM Kernel doesn't support regular gemm.");
static_assert(
(cute::is_tuple<ElementAOptionalTuple>::value ^ cute::is_tuple<ElementBOptionalTuple>::value
|| (NeitherIsTuple && (cute::sizeof_bits<ElementA>::value != cute::sizeof_bits<ElementB>::value))),
"Either A OR B must be a tuple or the widths of A and B must be different.");
using ElementPairA = cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>,
ElementAOptionalTuple>;
using ElementPairB = cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>,
ElementBOptionalTuple>;
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
static_assert(IsATransformed, "A matrix should be transformed.");
// For fp32 types, map to tf32 MMA value type.
using ElementMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
using ElementAMma = ElementMma;
using ElementBMma = ElementMma;
static constexpr int IsSubbyteA = cute::sizeof_bits_v<ElementA> < 8;
using TmaElementA = cute::conditional_t<IsSubbyteA, uint8_t, ElementA>;
static constexpr int ScalingFactor = 1;
using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma<ElementAMma, ElementB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB, KernelScheduleType>());
using AtomThrID = typename TiledMma::AtomThrID;
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{}));
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
using MmaShapeA_MK = decltype(partition_shape_A(
TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
using MmaShapeB_NK = decltype(partition_shape_B(
TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
using BlockTileA_M = decltype(cute::size<0, 0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
using BlockTileA_K = decltype(cute::size<0, 1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}));
// Input transform kernel can not use TMA 2SM instructions.
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA, ElementA,
BlockTileA_M, BlockTileA_K>());
using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA,
ElementAMma, BlockTileA_M, BlockTileA_K>());
using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomA,
SmemLayoutAtomACompute>;
static constexpr int MMA_M = cute::size<0, 0>(MmaShapeA_MK{});
using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>,
cute::conditional_t<
(UmmaMajorA == cute::UMMA::Major::K
&& !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>),
cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x,
SM100_TMEM_STORE_32dp32b8x>, // TS Implementation
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>> // SS Implementation
>;
using BlockTileB_N = decltype(cute::size<0, 0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
using BlockTileB_K = decltype(cute::size<0, 1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
// Input transform kernel can not use TMA 2SM instructions.
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB, ElementB,
BlockTileB_N, BlockTileB_K>());
using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB,
ElementBMma, BlockTileB_N, BlockTileB_K>());
using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomB,
SmemLayoutAtomBCompute>;
using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementB>,
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementMma>>;
// Creating the stride of Transformed Input
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
using LayoutScale = cutlass::gemm::TagToStrideA_t<GmemLayoutScaleTag>;
using VoidShapeScale
= Shape<Shape<Int<128>, _1>, Shape<Int<64>, _1>, _1>; // Dummy Value to create a dummy ScaleConfig
using VoidStrideScale = Stride<Stride<_0, _1>, Stride<_0, _1>, _1>;
using VoidLayoutScale = Layout<VoidShapeScale, VoidStrideScale>;
using NonVoidLayoutScale = cute::conditional_t<cute::is_void_v<LayoutScale>, VoidLayoutScale, LayoutScale>;
using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{}));
// SmemCarveout
static constexpr int SchedulerPipelineStageCount = 3;
static constexpr bool IsArrayOfPointersGemm
= (cute::is_base_of_v<KernelScheduleSm100PtrArrayFastFP32Gemm, KernelScheduleType>);
// CLCPipeline = PipelineCLCFetchAsync
static constexpr auto CLCPipelineStorage
= sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
// CLC (scheduler) response
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
// CLC Throttle pipeline storage
static constexpr auto CLCThrottlePipelineStorage
= sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
// Tmem dealloc
static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
// Tmem ptr storage
static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t);
// Tensormap Storage
static constexpr size_t TensorMapStorage
= IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
static constexpr auto KernelSmemCarveout = static_cast<int>(CLCPipelineStorage + CLCResponseStorage
+ CLCThrottlePipelineStorage + TmemDeallocStorage + TmemBasePtrsStorage + TensorMapStorage);
// Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations
static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
static constexpr int ScaleGranularityK = get_ScaleGranularityK<LayoutScale>();
static constexpr auto stage_info
= cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_weightonly<
Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB,
CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{});
static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info);
static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info);
static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info);
static_assert(!IsArrayOfPointersGemm, "mixed input does not support grouped gemm on Blackwell");
using DispatchPolicy
= cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput<Load2TransformPipelineStageCount,
Transform2MmaPipelineStageCount, SchedulerPipelineStageCount, AccumulatorPipelineStageCount,
ClusterShape_MNK>;
using CollectiveOp = cutlass::gemm::collective::CollectiveMmaSm100WeightOnly<DispatchPolicy, TileShape_MNK,
ElementPairA, StridePairA, ElementPairB, cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>, TiledMma,
GmemTiledCopyA, SmemLayoutAtomPairA, CopyAtomPairA, cute::identity, GmemTiledCopyB, SmemLayoutAtomPairB,
CopyAtomPairB, cute::identity>;
};
} // namespace cutlass::gemm::collective
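The tuple-promotion logic near the top of this builder (ElementPairA / IsATransformed) can be mirrored with standard-library types. The sketch below is a condensed stand-in that uses std::tuple instead of cute::tuple, purely to show how a bare narrow A operand gets wrapped so the mainloop always sees "A plus optional scales"; it is an illustration, not the CUTLASS definitions.
#include <cstdint>
#include <tuple>
#include <type_traits>

template <class T> struct is_tuple : std::false_type {};
template <class... Ts> struct is_tuple<std::tuple<Ts...>> : std::true_type {};

using ElementA = std::int8_t; // narrow operand, passed without scales
using ElementB = float;       // wide operand

constexpr bool NeitherIsTuple = !is_tuple<ElementA>::value && !is_tuple<ElementB>::value;
constexpr bool IsANarrow = sizeof(ElementA) < sizeof(ElementB);

// A bare narrow A is promoted to a single-element tuple, matching the ElementPairA logic above.
using ElementPairA = std::conditional_t<IsANarrow && NeitherIsTuple, std::tuple<ElementA>, ElementA>;
static_assert(is_tuple<ElementPairA>::value, "A is the transformed operand, so it ends up tuple-wrapped");

int main() { return 0; }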

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp"
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType, class Enable = void>
struct CollectiveBuilderSm100WeightOnly
{
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB>
struct CollectiveMmaSm100WeightOnly
{
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -533,8 +533,8 @@ struct GemmFpAIntB
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 1000)
// Use SM80 implementation for GB10x, GB20x.
#elif (__CUDA_ARCH__ >= 1200)
// Use SM80 implementation for GB20x.
run_kernel<arch::Sm80>(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.

View File

@ -87,7 +87,9 @@ public:
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
&& Arch::kMinComputeCapability != 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
@ -102,7 +104,9 @@ public:
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
&& Arch::kMinComputeCapability != 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
@ -116,6 +120,26 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
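The net effect of the new specializations is that SM100/SM103 fall out of the interleaved-dequantize path and get a plain column-major B with the regular multiply-add operator tag. The standalone sketch below mimics the enable_if partitioning with plain integers standing in for the Arch tags; it is illustrative only, not the CUTLASS types.
#include <type_traits>

template <int Cc, class Enable = void> struct LayoutDetailsB; // primary template, never instantiated here

// Turing and newer, except SM100/SM103: interleaved B layout + dequantize-in-mainloop operator.
template <int Cc>
struct LayoutDetailsB<Cc, std::enable_if_t<(Cc >= 75) && Cc != 100 && Cc != 103>>
{
    static constexpr bool UsesInterleavedDequantize = true;
};

// SM100/SM103: plain ColumnMajor B and the regular OpMultiplyAdd operator tag.
template <int Cc>
struct LayoutDetailsB<Cc, std::enable_if_t<Cc == 100 || Cc == 103>>
{
    static constexpr bool UsesInterleavedDequantize = false;
};

static_assert(LayoutDetailsB<80>::UsesInterleavedDequantize);
static_assert(!LayoutDetailsB<100>::UsesInterleavedDequantize);
static_assert(!LayoutDetailsB<103>::UsesInterleavedDequantize);

int main() { return 0; }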

View File

@ -38,7 +38,13 @@ foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES})
if(FILE_EXT STREQUAL ".py")
# Read file content and replace module imports for Python files
file(READ ${SOURCE_FILE} _content)
string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content
string(REPLACE "from . import _C" "import tensorrt_llm.deep_gemm_cpp_tllm"
_content "${_content}")
string(REPLACE ".._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
"${_content}")
string(REPLACE "._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
"${_content}")
string(REPLACE "_C." "tensorrt_llm.deep_gemm_cpp_tllm." _content
"${_content}")
# Add adaptation header

View File

@ -90,4 +90,5 @@ target_compile_definitions(${EXECUTOR_STATIC_TARGET}
PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
add_subdirectory(cache_transmission/ucx_utils)
add_subdirectory(cache_transmission/mooncake_utils)
add_subdirectory(cache_transmission/nixl_utils)

View File

@ -141,7 +141,8 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
NotificationInfo notificationInfo{syncInfo};
std::stringstream ss;
NotificationInfo::serialize(notificationInfo, ss);
status->wait();
TransferState transferState = status->wait();
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "AgentConnection::send failed");
// TODO: there is a bug in request_with_notify https://github.com/ai-dynamo/nixl/pull/252
mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
}
@ -150,7 +151,7 @@ void AgentConnection::recv(DataContext const& ctx, void* data, size_t size) cons
{
NotificationSyncInfo syncInfo{mAgentName, ctx};
mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo);
mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo, ctx.getTransferTerminate());
}
void AgentConnection::sendRequestAndBufferInfo(batch_manager::RequestInfo& requestInfo,
@ -230,13 +231,13 @@ void AgentConnection::sendReadySignal(DataContext const& ctx, bool isReady) cons
bool AgentConnection::recvReadySignal(DataContext const& ctx) const
{
ReadySignalInfo readySignalInfo{mAgentName, ctx, false};
mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo);
return true;
mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo, ctx.getTransferTerminate());
return readySignalInfo.mIsReady;
}
AgentConnectionManager::AgentConnectionManager(
std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
CacheState cacheState)
CacheState cacheState, std::string const& backendType)
: mCacheState(std::move(cacheState))
, mCacheTransBufferManagers(std::move(cacheTransBufferManagers))
, mRegMemDescs(MemoryType::kVRAM, {})
@ -246,8 +247,8 @@ AgentConnectionManager::AgentConnectionManager(
mAgentName = genUniqueAgentName();
// Create Agent
BaseAgentConfig config{mAgentName, true};
m_Agent = makeTransferAgent("nixl", &config);
BaseAgentConfig config{mAgentName, true, false, true, 1};
m_Agent = makeTransferAgent(backendType, &config);
TLLM_CHECK(!mCacheTransBufferManagers.empty());
std::vector<MemoryDesc> memDescs;
for (auto* cacheTransBufferManager : mCacheTransBufferManagers)
@ -315,9 +316,10 @@ AgentConnectionManager::AgentConnectionManager(
" ***** AgentConnectionManager::AgentConnectionManager mCommState: %s", mCommState.toString().c_str());
}
AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batch_manager::RequestInfo& requestInfo)
AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(
batch_manager::RequestInfo& requestInfo, std::atomic<bool> const& terminateFlag)
{
while (true)
while (!terminateFlag.load())
{
if (!mIsRunning)
{
@ -490,16 +492,16 @@ int AgentConnectionManager::getDeviceId() const
}
template <typename NotificationType>
void AgentConnectionManager::waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo)
void AgentConnectionManager::waitForNotification(
std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag)
{
while (true)
while (!terminateFlag.load())
{
if (!mIsRunning)
{
return;
}
updateUnhandledNotifications();
std::scoped_lock lock(mNotificationMutex);
auto it = mUnhandledNotifications.begin();
@ -575,18 +577,20 @@ void AgentConnectionManager::waitForNotification(std::string const& remoteAgentN
// Explicit template instantiations
template void AgentConnectionManager::waitForNotification<NotificationSyncInfo>(
std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo);
std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo, std::atomic<bool> const& terminateFlag);
template void AgentConnectionManager::waitForNotification<ReadySignalInfo>(
std::string const& remoteAgentName, ReadySignalInfo& expectedInfo);
std::string const& remoteAgentName, ReadySignalInfo& expectedInfo, std::atomic<bool> const& terminateFlag);
void AgentConnectionManager::waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo)
void AgentConnectionManager::waitForSyncInfo(
std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag)
{
waitForNotification(remoteAgentName, syncInfo);
waitForNotification(remoteAgentName, syncInfo, terminateFlag);
}
void AgentConnectionManager::waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo)
void AgentConnectionManager::waitForReadySignal(
std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag)
{
waitForNotification(remoteAgentName, readySignalInfo);
waitForNotification(remoteAgentName, readySignalInfo, terminateFlag);
}
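The recurring change in this file is replacing unconditional while (true) polling with a caller-supplied terminate flag. A minimal standalone sketch of that pattern, with pollOnce() as a stand-in for the notification lookup, follows.
#include <atomic>
#include <chrono>
#include <thread>

bool pollOnce() { return false; } // stand-in for updateUnhandledNotifications() + matching the expected info

void waitForNotification(std::atomic<bool> const& terminateFlag)
{
    while (!terminateFlag.load())
    {
        if (pollOnce())
        {
            return; // expected notification arrived and was handled
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    // Returning here instead of spinning forever lets a cancelled or terminated request unblock the worker.
}

int main()
{
    std::atomic<bool> terminate{false};
    std::thread waiter([&] { waitForNotification(terminate); });
    terminate.store(true); // e.g. the flag behind ctx.getTransferTerminate() flipped by a cancellation path
    waiter.join();
    return 0;
}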
std::string const& AgentConnectionManager::getAgentName() const

View File

@ -277,12 +277,13 @@ class AgentConnectionManager : public ConnectionManager
public:
AgentConnectionManager(
std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
CacheState cacheState);
CacheState cacheState, std::string const& backendType);
~AgentConnectionManager();
AgentConnection* recvConnect(DataContext const& ctx, void* data, size_t size) override;
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
[[nodiscard]] CommState const& getCommState() const override;
AgentConnection const* recvConnectionAndRequestInfo(batch_manager::RequestInfo& requestInfo);
AgentConnection const* recvConnectionAndRequestInfo(
batch_manager::RequestInfo& requestInfo, std::atomic<bool> const& terminateFlag);
[[nodiscard]] std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> const&
getCacheTransBufferManagers() const;
void updateUnhandledNotifications();
@ -293,9 +294,12 @@ public:
[[nodiscard]] std::string const& getAgentName() const;
template <typename NotificationType>
void waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo);
void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo);
void waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo);
void waitForNotification(
std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag);
void waitForSyncInfo(
std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag);
void waitForReadySignal(
std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag);
[[nodiscard]] bool isRunning() const override;
private:

View File

@ -107,9 +107,9 @@ TargetRanksInfo TargetRanksInfoForDP(
auto const peerCPNum = peerParConfig.mContextParallelism;
auto const selfCPNum = selfParConfig.mContextParallelism;
auto const selfTPRank = selfRank % selfTPNum;
auto const selfCPRank = selfRank % selfCPNum;
auto const selfTPRank = (selfRank % (selfTPNum * selfCPNum)) / selfCPNum;
auto const selfPPRank = selfRank / (selfTPNum * selfCPNum);
auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) % selfCPNum;
int peerPPRankStart = 0;
int mDomainPPSize = 1;
@ -205,13 +205,14 @@ TargetRanksInfo TargetRanksInfoForDP(
}
std::vector<int> retRanks;
for (int i = peerTPRankStart; i < peerTPRankEnd; i++)
for (int i = peerCPRankStart; i < peerCPRankEnd; i++)
{
for (int j = peerCPRankStart; j < peerCPRankEnd; j++)
for (int j = peerTPRankStart; j < peerTPRankEnd; j++)
{
for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
{
int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i;
// Rank formula: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
int irank = (k * peerTPNum * peerCPNum) + (j * peerCPNum) + i;
retRanks.push_back(irank);
}
}
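A short standalone check of the rank formula in the comment above (pp outermost, then tp, then cp); the parallelism sizes here are arbitrary example values.
#include <cassert>

int main()
{
    constexpr int tpNum = 2, cpNum = 3, ppNum = 2;
    for (int pp = 0; pp < ppNum; ++pp)
        for (int tp = 0; tp < tpNum; ++tp)
            for (int cp = 0; cp < cpNum; ++cp)
            {
                // Rank formula: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
                int const rank = pp * (tpNum * cpNum) + tp * cpNum + cp;
                // Inverse mapping, matching the selfPPRank/selfTPRank/selfCPRank derivation.
                int const ppRank = rank / (tpNum * cpNum);
                int const tpRank = (rank % (tpNum * cpNum)) / cpNum;
                int const cpRank = (rank % (tpNum * cpNum)) % cpNum;
                assert(ppRank == pp && tpRank == tp && cpRank == cp);
            }
    return 0;
}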

View File

@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
# Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this material and related documentation without an express
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.
# MOONCAKE is not supported on Rocky8 for now
set(IS_ROCKY8 FALSE)
if(EXISTS "/etc/redhat-release")
set(IS_ROCKY8 TRUE)
endif()
if(MOONCAKE_ROOT AND NOT IS_ROCKY8)
find_library(TRANSFER_ENGINE_LIB transfer_engine ${MOONCAKE_ROOT}/lib)
find_path(TRANSFER_ENGINE_INCLUDE_DIR transfer_engine_c.h
${MOONCAKE_ROOT}/include)
message(STATUS "Find transfer engine results:")
message(STATUS " TRANSFER_ENGINE_LIB = ${TRANSFER_ENGINE_LIB}")
message(
STATUS " TRANSFER_ENGINE_INCLUDE_DIR = ${TRANSFER_ENGINE_INCLUDE_DIR}")
if(TRANSFER_ENGINE_LIB AND TRANSFER_ENGINE_INCLUDE_DIR)
set(MOONCAKE_WRAPPER_TARGET "tensorrt_llm_mooncake_wrapper")
add_library(${MOONCAKE_WRAPPER_TARGET} SHARED transferAgent.cpp)
target_compile_options(${MOONCAKE_WRAPPER_TARGET} PRIVATE -Wno-error)
target_include_directories(${MOONCAKE_WRAPPER_TARGET}
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
target_link_libraries(${MOONCAKE_WRAPPER_TARGET}
PRIVATE ${TRANSFER_ENGINE_LIB} CUDA::cudart)
# Export variables to parent scope for transfer_agent_binding
set(TRANSFER_ENGINE_INCLUDE_DIR
${TRANSFER_ENGINE_INCLUDE_DIR}
PARENT_SCOPE)
endif()
endif()

View File

@ -0,0 +1,612 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/ipUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/transferAgent.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <algorithm>
#include <arpa/inet.h>
#include <chrono>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/file.h>
#include <sys/stat.h>
#include <thread>
#include <unistd.h>
namespace tensorrt_llm::executor::kv_cache
{
MooncakeTransferStatus::MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount)
: mEngine{engine}
, mBatchId{batchId}
, mRequestCount{requestCount}
{
TLLM_CHECK(mEngine);
}
TransferState MooncakeTransferStatus::wait(int64_t timeout_ms) const
{
auto startTime = std::chrono::steady_clock::now();
while (true)
{
if (mBatchFreed)
{
return TransferState::kSUCCESS;
}
bool has_failed = false;
bool all_completed = true;
for (size_t index = 0; index < mRequestCount; ++index)
{
transfer_status_t status;
int rc = getTransferStatus(mEngine, mBatchId, index, &status);
if (rc || status.status == STATUS_FAILED)
{
has_failed = true;
if (rc)
{
TLLM_LOG_ERROR(
"Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
}
else
{
TLLM_LOG_ERROR(
"Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
}
}
else if (status.status != STATUS_COMPLETED)
{
all_completed = false;
}
}
// If any request failed, return failure
if (has_failed)
{
return TransferState::kFAILURE;
}
// If all requests completed successfully
if (all_completed)
{
freeBatchID(mEngine, mBatchId);
mBatchFreed = true;
TLLM_LOG_DEBUG("Batch ID %lu freed in wait()", mBatchId);
syncSegmentCache(mEngine);
return TransferState::kSUCCESS;
}
// If timeout_ms < 0, wait indefinitely
if (timeout_ms < 0)
{
std::this_thread::yield();
continue;
}
// Check if timeout has elapsed
auto elapsed
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
.count();
if (elapsed >= timeout_ms)
{
return TransferState::kIN_PROGRESS;
}
std::this_thread::yield();
}
}
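Note the timeout contract of wait(): with a non-negative timeout_ms it returns kIN_PROGRESS once the time slice elapses instead of blocking forever. A hedged usage fragment (status comes from an earlier submitTransferRequests call; shouldCancel() is a placeholder for the caller's own cancellation check):
TransferState state = TransferState::kIN_PROGRESS;
while (state == TransferState::kIN_PROGRESS && !shouldCancel())
{
    state = status->wait(/*timeout_ms=*/100); // poll in 100 ms slices so the caller stays responsive
}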
[[nodiscard]] bool MooncakeTransferStatus::isCompleted() const
{
if (mBatchFreed)
{
return true;
}
bool has_failed = false;
for (size_t index = 0; index < mRequestCount; ++index)
{
transfer_status_t status;
int rc = getTransferStatus(mEngine, mBatchId, index, &status);
if (rc || status.status == STATUS_FAILED)
{
has_failed = true;
if (rc)
{
TLLM_LOG_ERROR(
"Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
}
else
{
TLLM_LOG_ERROR("Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
}
}
else if (status.status == STATUS_PENDING || status.status == STATUS_WAITING)
{
TLLM_LOG_DEBUG("Transfer is pending for batch %lu, task %zu", mBatchId, index);
return false;
}
}
if (!has_failed)
{
// Each batchId is allocated with a fixed batch size and cannot process more requests
// than that size. Free the batch ID here to work around the issue where the same
// batchId could otherwise be reused to post multiple transfers.
freeBatchID(mEngine, mBatchId);
mBatchFreed = true;
TLLM_LOG_DEBUG("Batch ID %lu freed, future calls will return true directly", mBatchId);
}
// Currently, we cannot distinguish between failure and completion from the return value.
TLLM_LOG_DEBUG("Transfer is completed for batch %lu", mBatchId);
return true;
}
std::string const MooncakeBase64Helper::STANDARD_CHARS
= "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
std::string MooncakeBase64Helper::encode(std::vector<uint8_t> const& data)
{
return encodeInternal(data, STANDARD_CHARS);
}
std::string MooncakeBase64Helper::encode(std::string const& data)
{
std::vector<uint8_t> vec(data.begin(), data.end());
return encode(vec);
}
std::vector<uint8_t> MooncakeBase64Helper::decode(std::string const& encoded)
{
return decodeInternal(encoded, STANDARD_CHARS);
}
std::string MooncakeBase64Helper::decodeToString(std::string const& encoded)
{
auto vec = decode(encoded);
return std::string(vec.begin(), vec.end());
}
std::string MooncakeBase64Helper::encodeInternal(std::vector<uint8_t> const& data, std::string const& chars)
{
std::string encoded;
size_t i = 0;
size_t j = 0;
std::array<uint8_t, 3> charArray3{};
std::array<uint8_t, 4> charArray4{};
size_t dataLen = data.size();
uint8_t const* bytes = data.data();
while (dataLen--)
{
charArray3[i++] = *(bytes++);
if (i == 3)
{
charArray4[0] = (charArray3[0] & 0xfc) >> 2;
charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
charArray4[3] = charArray3[2] & 0x3f;
for (i = 0; i < 4; i++)
{
encoded += chars[charArray4[i]];
}
i = 0;
}
}
if (i > 0)
{
for (j = i; j < 3; j++)
{
charArray3[j] = '\0';
}
charArray4[0] = (charArray3[0] & 0xfc) >> 2;
charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
charArray4[3] = charArray3[2] & 0x3f;
for (j = 0; j < i + 1; j++)
{
encoded += chars[charArray4[j]];
}
while (i++ < 3)
{
encoded += '=';
}
}
return encoded;
}
std::vector<uint8_t> MooncakeBase64Helper::decodeInternal(std::string const& encoded, std::string const& chars)
{
size_t encodedLen = encoded.size();
size_t i = 0;
size_t j = 0;
size_t in_ = 0;
std::array<uint8_t, 3> charArray3{};
std::array<uint8_t, 4> charArray4{};
std::vector<uint8_t> decoded;
std::string cleanEncoded;
for (char c : encoded)
{
if (!isWhitespace(c))
{
cleanEncoded += c;
}
}
encodedLen = cleanEncoded.size();
while (encodedLen-- && cleanEncoded[in_] != '=' && isBase64(cleanEncoded[in_], chars))
{
charArray4[i++] = cleanEncoded[in_];
in_++;
if (i == 4)
{
for (i = 0; i < 4; i++)
{
charArray4[i] = chars.find(charArray4[i]);
}
charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];
for (i = 0; i < 3; i++)
{
decoded.push_back(charArray3[i]);
}
i = 0;
}
}
if (i > 0)
{
for (j = i; j < 4; j++)
{
charArray4[j] = 0;
}
for (j = 0; j < 4; j++)
{
charArray4[j] = chars.find(charArray4[j]);
}
charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];
for (j = 0; j < i - 1; j++)
{
decoded.push_back(charArray3[j]);
}
}
return decoded;
}
bool MooncakeBase64Helper::isBase64(uint8_t c, std::string const& chars)
{
return (isalnum(c) || (c == chars[62]) || (c == chars[63]));
}
bool MooncakeBase64Helper::isWhitespace(uint8_t c)
{
return (c == ' ' || c == '\n' || c == '\r' || c == '\t');
}
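A small round-trip sanity check for the helper above; it assumes the mooncake_utils header is on the include path and the wrapper library is linked, and uses the standard alphabet.
#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"
#include <cassert>
#include <string>

int main()
{
    using tensorrt_llm::executor::kv_cache::MooncakeBase64Helper;
    std::string const msg = "hello";
    std::string const encoded = MooncakeBase64Helper::encode(msg); // "aGVsbG8=" with the standard alphabet
    assert(MooncakeBase64Helper::decodeToString(encoded) == msg);
    return 0;
}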
MooncakeTransferAgent::MooncakeTransferAgent(BaseAgentConfig const& config)
{
mLocalAgentName = config.mName;
std::string segmentName = "127.0.0.1";
if (getenv("TLLM_MOONCAKE_IP_ADDR"))
{
segmentName = std::string(getenv("TLLM_MOONCAKE_IP_ADDR"));
}
else
{
auto ip = common::getLocalIp(common::getEnvMooncakeInterface(), mpi::MpiComm::session().getRank());
if (!ip.empty())
segmentName = ip;
}
mEngine = createTransferEngine("P2PHANDSHAKE", segmentName.c_str(), "", 0, true);
}
void MooncakeTransferAgent::registerMemory(RegisterDescs const& descs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::registerMemory");
std::lock_guard<std::mutex> lock(mMutex);
for (auto const& desc : descs.getDescs())
{
auto it = mMemRegInfo.find(desc.getAddr());
if (it != mMemRegInfo.end())
{
it->second->addRef();
continue;
}
int err = registerLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()), desc.getLen(), "*", 1);
TLLM_CHECK_WITH_INFO(err == 0, "registerLocalMemory failed, addr: %p, len: %lu",
reinterpret_cast<void*>(desc.getAddr()), desc.getLen());
auto mooncakeDesc = std::make_shared<MooncakeMemoryDesc>(desc);
mMemRegInfo[desc.getAddr()] = std::move(mooncakeDesc);
}
}
void MooncakeTransferAgent::deregisterMemory(RegisterDescs const& descs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::deregisterMemory");
std::lock_guard<std::mutex> lock(mMutex);
for (auto const& desc : descs.getDescs())
{
auto it = mMemRegInfo.find(desc.getAddr());
if (it != mMemRegInfo.end())
{
auto const& mooncakeDesc = it->second;
mooncakeDesc->releaseRef();
if (mooncakeDesc->getRefCount())
continue;
int err = unregisterLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()));
TLLM_CHECK_WITH_INFO(
err == 0, "unregisterLocalMemory failed, addr: %p", reinterpret_cast<void*>(desc.getAddr()));
mMemRegInfo.erase(desc.getAddr());
}
}
}
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::loadRemoteAgent");
// Do the same thing as loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
loadRemoteAgent(name, std::move(agentDesc.getBackendAgentDesc()));
}
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"MooncakeTransferAgent::loadRemoteAgent loadRemoteAgent to %s remoteagent name: %s", connectionInfo.c_str(),
name.c_str());
std::lock_guard<std::mutex> lock(mMutex);
auto segmentId = openSegment(mEngine, connectionInfo.c_str());
TLLM_CHECK_WITH_INFO(
segmentId >= 0, "loadRemoteAgent openSegment failed, connectionInfo: %s", connectionInfo.c_str());
mConnectedAgents[name].segmentId = segmentId;
}
void MooncakeTransferAgent::invalidateRemoteAgent(std::string const& name)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::invalidateRemoteAgent");
}
AgentDesc MooncakeTransferAgent::getLocalAgentDesc()
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalAgentDesc");
// Using connection info as agent desc
static size_t const kBufLen = 64;
char connectionInfo[kBufLen];
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalIpAndPort failed");
return AgentDesc{std::string(connectionInfo)};
}
ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalConnectionInfo");
static size_t const kBufLen = 64;
char connectionInfo[kBufLen];
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalConnectionInfo failed");
return std::string(connectionInfo);
}
[[nodiscard]] std::unique_ptr<TransferStatus> MooncakeTransferAgent::submitTransferRequests(
TransferRequest const& request)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::submitTransferRequests");
bool hasNotif = false;
std::string syncMessage;
if (request.getSyncMessage().has_value())
{
hasNotif = true;
syncMessage = request.getSyncMessage().value();
}
static size_t const kMaxRequestCount = 1024;
uint64_t batchId = allocateBatchID(mEngine, kMaxRequestCount);
TLLM_CHECK_WITH_INFO(batchId != INVALID_BATCH, "allocateBatchID failed");
int segmentId;
{
std::lock_guard<std::mutex> lock(mMutex);
std::string remoteName = request.getRemoteName();
auto it = mConnectedAgents.find(remoteName);
if (it == mConnectedAgents.end())
{
std::string error = "Remote agent " + remoteName + "not found";
TLLM_THROW(error);
}
auto const& agentInfo = it->second;
segmentId = agentInfo.segmentId;
}
auto localDescs = request.getSrcDescs().getDescs();
auto remoteDescs = request.getDstDescs().getDescs();
TLLM_CHECK_WITH_INFO(localDescs.size() == remoteDescs.size(), "Number of local and remote memory must match");
size_t requestCount = localDescs.size();
std::vector<transfer_request_t> transferRequests(requestCount);
for (size_t index = 0; index < requestCount; ++index)
{
TLLM_CHECK_WITH_INFO(
localDescs[index].getLen() == remoteDescs[index].getLen(), "Length of local and remote memory must match");
transferRequests[index].opcode = (request.getOp() == TransferOp::kREAD) ? OPCODE_READ : OPCODE_WRITE;
transferRequests[index].source = reinterpret_cast<void*>(localDescs[index].getAddr());
transferRequests[index].target_offset = remoteDescs[index].getAddr();
transferRequests[index].length = localDescs[index].getLen();
transferRequests[index].target_id = segmentId;
}
int rc = 0;
if (hasNotif)
{
notify_msg_t notifyMsg;
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
notifyMsg.msg = const_cast<char*>(syncMessage.c_str());
rc = submitTransferWithNotify(mEngine, batchId, transferRequests.data(), requestCount, notifyMsg);
}
else
{
rc = submitTransfer(mEngine, batchId, transferRequests.data(), requestCount);
}
TLLM_CHECK_WITH_INFO(rc == 0, "submitTransfer failed with status: %d", rc);
return std::make_unique<MooncakeTransferStatus>(mEngine, batchId, requestCount);
}
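A hedged sketch of how a caller would drive this path: the addresses, lengths, device id, agent object, and remote name below are placeholders, and TransferDescs is assumed to be constructible from MemoryDescs (an assumption, not verified here).
using namespace tensorrt_llm::executor::kv_cache;

// srcAddr/dstAddr/len/deviceId and 'agent' are placeholders; both regions must already be registered.
MemoryDescs srcDescs{MemoryType::kVRAM, {MemoryDesc{srcAddr, len, deviceId}}};
MemoryDescs dstDescs{MemoryType::kVRAM, {MemoryDesc{dstAddr, len, deviceId}}};
TransferRequest request{TransferOp::kWRITE, srcDescs, dstDescs, /*remoteName=*/"peer_agent",
    /*syncMessage=*/std::nullopt};
auto status = agent.submitTransferRequests(request);
TLLM_CHECK_WITH_INFO(status->wait() == TransferState::kSUCCESS, "mooncake transfer failed");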
void MooncakeTransferAgent::notifySyncMessage(std::string const& name, SyncMessage const& syncMessage)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage");
int segmentId;
{
std::lock_guard<std::mutex> lock(mMutex);
auto it = mConnectedAgents.find(name);
if (it == mConnectedAgents.end())
{
TLLM_LOG_WARNING("Remote agent %s not found", name.c_str());
return;
}
auto const& agentInfo = it->second;
segmentId = agentInfo.segmentId;
}
notify_msg_t notifyMsg;
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
std::string encoded = MooncakeBase64Helper::encode(syncMessage);
notifyMsg.msg = const_cast<char*>(encoded.c_str());
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage notifyMsg.name: %s, notifyMsg.msg: %s", notifyMsg.name,
notifyMsg.msg);
int ret = genNotifyInEngine(mEngine, segmentId, notifyMsg);
TLLM_CHECK_WITH_INFO(ret == 0, "genNotifyInEngine failed with status: %d", ret);
}
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> MooncakeTransferAgent::getNotifiedSyncMessages()
{
std::unordered_map<std::string, std::vector<SyncMessage>> notifs;
int size = 0;
notify_msg_t* notifyMsgs = getNotifsFromEngine(mEngine, &size);
TLLM_CHECK_WITH_INFO(size >= 0, "getNotifsFromEngine returned negative size: %d", size);
for (int i = 0; i < size; i++)
{
if (notifyMsgs[i].msg == nullptr)
{
TLLM_LOG_WARNING("Message pointer is null for: %s", notifyMsgs[i].name);
continue;
}
std::string decoded = MooncakeBase64Helper::decodeToString(notifyMsgs[i].msg);
notifs[notifyMsgs[i].name].emplace_back(std::move(decoded));
TLLM_LOG_DEBUG("MooncakeTransferAgent::getNotifiedSyncMessages getNotifsFromEngine: %s, %s", notifyMsgs[i].name,
notifyMsgs[i].msg);
}
freeNotifsMsgBuf(notifyMsgs, size);
return notifs;
}
bool MooncakeTransferAgent::checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::checkRemoteDescs");
return true;
}
MooncakeTransferAgent::~MooncakeTransferAgent()
{
destroyTransferEngine(mEngine);
TLLM_LOG_DEBUG("MooncakeTransferAgent::~MooncakeTransferAgent");
}
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
extern "C"
{
std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config)
{
TLLM_CHECK(config);
return std::make_unique<MooncakeTransferAgent>(*config);
}
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
} // namespace tensorrt_llm::executor::kv_cache

View File

@ -0,0 +1,165 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <atomic>
#include <mutex>
#include <thread>
#include <vector>
#include "tensorrt_llm/executor/transferAgent.h"
#include "transfer_engine_c.h"
namespace tensorrt_llm::executor::kv_cache
{
class MooncakeTransferStatus final : public TransferStatus
{
public:
MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount);
[[nodiscard]] bool isCompleted() const override;
TransferState wait(int64_t timeout_ms = -1) const override;
private:
transfer_engine_t mEngine;
uint64_t mBatchId;
size_t mRequestCount;
mutable bool mBatchFreed = false;
};
class MooncakeMemoryDesc
{
public:
MooncakeMemoryDesc(MemoryDesc desc)
: mDesc{std::move(desc)}
, mRefCnt{0}
{
}
MooncakeMemoryDesc(MooncakeMemoryDesc const& other)
: mDesc{other.mDesc}
, mRefCnt{0}
{
}
MooncakeMemoryDesc& operator=(MooncakeMemoryDesc const&) = delete;
~MooncakeMemoryDesc() = default;
void addRef() noexcept
{
++mRefCnt;
}
int releaseRef() noexcept
{
return --mRefCnt;
}
int getRefCount() const noexcept
{
return mRefCnt;
}
MemoryDesc const& getDesc() const noexcept
{
return mDesc;
}
private:
MemoryDesc mDesc;
int mRefCnt;
};
class MooncakeBase64Helper
{
public:
static std::string encode(std::vector<uint8_t> const& data);
static std::string encode(std::string const& data);
static std::vector<uint8_t> decode(std::string const& encoded);
static std::string decodeToString(std::string const& encoded);
private:
static const std::string STANDARD_CHARS;
static std::string encodeInternal(std::vector<uint8_t> const& data, std::string const& chars);
static std::vector<uint8_t> decodeInternal(std::string const& encoded, std::string const& chars);
static inline bool isBase64(uint8_t c, std::string const& chars);
static inline bool isWhitespace(uint8_t c);
};
class MooncakeTransferAgent final : public BaseTransferAgent
{
public:
MooncakeTransferAgent(BaseAgentConfig const& config);
~MooncakeTransferAgent();
void registerMemory(RegisterDescs const& descs) override;
void deregisterMemory(RegisterDescs const& descs) override;
void loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc) override;
void loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) override;
void invalidateRemoteAgent(std::string const& name) override;
AgentDesc getLocalAgentDesc() override;
ConnectionInfoType getLocalConnectionInfo() override;
[[nodiscard]] std::unique_ptr<TransferStatus> submitTransferRequests(TransferRequest const& request) override;
void notifySyncMessage(std::string const& name, SyncMessage const& syncMessage) override;
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> getNotifiedSyncMessages() override;
bool checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs) override;
private:
struct AgentInfo
{
int segmentId;
};
mutable std::mutex mMutex;
transfer_engine_t mEngine;
std::unordered_map<uintptr_t, std::shared_ptr<MooncakeMemoryDesc>> mMemRegInfo;
std::unordered_map<std::string, AgentInfo> mConnectedAgents;
std::string mLocalAgentName;
};
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
extern "C"
{
[[nodiscard]] std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config);
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
} // namespace tensorrt_llm::executor::kv_cache
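A hedged lifecycle sketch for the API declared above. The agent names, buffer address/length, device id, and how the peers exchange connection info out of band are all placeholders; RegisterDescs is assumed to be constructible the same way as MemoryDescs.
using namespace tensorrt_llm::executor::kv_cache;

// Mirrors the construction in AgentConnectionManager: name, useProgThread, multiThread, useListenThread, numWorkers.
BaseAgentConfig config{"agent_ctx_0", true, false, true, 1};
MooncakeTransferAgent agent{config};

// bufferAddr/bufferLen/deviceId are placeholders for an already-allocated device buffer.
agent.registerMemory(RegisterDescs{MemoryType::kVRAM, {MemoryDesc{bufferAddr, bufferLen, deviceId}}});

// Publish our "ip:port" so the peer can open a segment to us, and load theirs once received out of band.
ConnectionInfoType localInfo = agent.getLocalConnectionInfo();
agent.loadRemoteAgent("agent_gen_0", peerConnectionInfo); // peerConnectionInfo obtained out of band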

View File

@ -13,6 +13,9 @@
# License for the specific language governing permissions and limitations under
# the License.
# ============================================================================
# NIXL Wrapper Library
# ============================================================================
if(NIXL_ROOT)
find_package(NIXL REQUIRED)
# Check if all required packages were found
@ -30,6 +33,8 @@ if(NIXL_ROOT)
# Add include directories
target_include_directories(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
target_include_directories(${NIXL_WRAPPER_TARGET}
PRIVATE ${PROJECT_SOURCE_DIR}/include)
# Link against all NIXL libraries
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
@ -37,4 +42,85 @@ if(NIXL_ROOT)
# Link against CUDA
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE CUDA::cudart)
set(NIXL_ENABLED TRUE)
else()
set(NIXL_ENABLED FALSE)
endif()
# ============================================================================
# Check if Mooncake wrapper is available (built in mooncake_utils)
# ============================================================================
if(MOONCAKE_ROOT AND TARGET tensorrt_llm_mooncake_wrapper)
set(MOONCAKE_ENABLED TRUE)
else()
set(MOONCAKE_ENABLED FALSE)
endif()
# ============================================================================
# TensorRT-LLM Transfer Agent Binding Python Module Build if either NIXL or
# Mooncake is enabled
# ============================================================================
if(NIXL_ENABLED OR MOONCAKE_ENABLED)
set(TRANSFER_AGENT_BINDING_TARGET "tensorrt_llm_transfer_agent_binding")
# Collect binding source files
set(AGENT_BINDING_SOURCES "")
if(BINDING_TYPE STREQUAL "pybind")
list(APPEND AGENT_BINDING_SOURCES agentBindingsPybind.cpp)
else()
list(APPEND AGENT_BINDING_SOURCES agentBindingsNanobind.cpp)
endif()
if(BINDING_TYPE STREQUAL "pybind")
# Use pybind11 (already fetched via FetchContent)
pybind11_add_module(${TRANSFER_AGENT_BINDING_TARGET}
${AGENT_BINDING_SOURCES})
message(STATUS "Building tensorrt_llm_transfer_agent_binding with pybind11")
else()
# Default to nanobind (already fetched via FetchContent)
nanobind_add_module(${TRANSFER_AGENT_BINDING_TARGET}
${AGENT_BINDING_SOURCES})
message(STATUS "Building tensorrt_llm_transfer_agent_binding with nanobind")
endif()
target_compile_options(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE -Wno-error)
# Add common include directories
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${PROJECT_SOURCE_DIR}/include)
# Conditionally add NIXL support
if(NIXL_ENABLED)
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ENABLE_NIXL)
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE NIXL::nixl)
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${NIXL_WRAPPER_TARGET})
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE NIXL::nixl)
message(STATUS "Transfer agent binding: NIXL support enabled")
endif()
# Conditionally add Mooncake support
if(MOONCAKE_ENABLED)
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ENABLE_MOONCAKE)
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE tensorrt_llm_mooncake_wrapper)
message(STATUS "Transfer agent binding: Mooncake support enabled")
endif()
# Common dependencies
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE CUDA::cudart)
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${SHARED_TARGET})
# Set RPATH for the module to find wrapper libraries
set_target_properties(
${TRANSFER_AGENT_BINDING_TARGET}
PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl"
INSTALL_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl")
endif()

View File

@ -0,0 +1,239 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/transferAgent.h"
#ifdef ENABLE_NIXL
#include "transferAgent.h"
#endif
#ifdef ENABLE_MOONCAKE
#include "../mooncake_utils/transferAgent.h"
#endif
#include <nanobind/nanobind.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
namespace nb = nanobind;
namespace kvc = tensorrt_llm::executor::kv_cache;
NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
{
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (nanobind)";
// MemoryType enum
nb::enum_<kvc::MemoryType>(m, "MemoryType")
.value("DRAM", kvc::MemoryType::kDRAM)
.value("VRAM", kvc::MemoryType::kVRAM)
.value("BLK", kvc::MemoryType::kBLK)
.value("OBJ", kvc::MemoryType::kOBJ)
.value("FILE", kvc::MemoryType::kFILE);
// TransferOp enum
nb::enum_<kvc::TransferOp>(m, "TransferOp")
.value("READ", kvc::TransferOp::kREAD)
.value("WRITE", kvc::TransferOp::kWRITE);
// TransferState enum
nb::enum_<kvc::TransferState>(m, "TransferState")
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
.value("SUCCESS", kvc::TransferState::kSUCCESS)
.value("FAILURE", kvc::TransferState::kFAILURE);
// MemoryDesc class
nb::class_<kvc::MemoryDesc>(m, "MemoryDesc")
.def(nb::init<uintptr_t, size_t, uint32_t>(), nb::arg("addr"), nb::arg("len"), nb::arg("device_id"))
.def_prop_ro("addr", &kvc::MemoryDesc::getAddr)
.def_prop_ro("len", &kvc::MemoryDesc::getLen)
.def_prop_ro("device_id", &kvc::MemoryDesc::getDeviceId);
// MemoryDescs class
nb::class_<kvc::MemoryDescs>(m, "MemoryDescs")
.def(nb::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), nb::arg("type"), nb::arg("descs"))
.def_prop_ro("type", &kvc::MemoryDescs::getType)
.def_prop_ro("descs", &kvc::MemoryDescs::getDescs);
// AgentDesc class
nb::class_<kvc::AgentDesc>(m, "AgentDesc")
.def(
"__init__",
[](kvc::AgentDesc* self, nb::bytes data)
{
std::string str(data.c_str(), data.size());
new (self) kvc::AgentDesc{std::move(str)};
},
nb::arg("backend_agent_desc"))
.def(nb::init<std::string>(), nb::arg("backend_agent_desc"))
.def_prop_ro("backend_agent_desc",
[](kvc::AgentDesc const& self)
{
auto const& desc = self.getBackendAgentDesc();
return nb::bytes(desc.data(), desc.size());
});
// TransferRequest class
nb::class_<kvc::TransferRequest>(m, "TransferRequest")
.def(nb::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
std::optional<kvc::SyncMessage>>(),
nb::arg("op"), nb::arg("src_descs"), nb::arg("dst_descs"), nb::arg("remote_name"),
nb::arg("sync_message") = std::nullopt)
.def_prop_ro("op", &kvc::TransferRequest::getOp)
.def_prop_ro("src_descs", &kvc::TransferRequest::getSrcDescs)
.def_prop_ro("dst_descs", &kvc::TransferRequest::getDstDescs)
.def_prop_ro("remote_name", &kvc::TransferRequest::getRemoteName)
.def_prop_ro("sync_message", &kvc::TransferRequest::getSyncMessage);
// TransferStatus base class
nb::class_<kvc::TransferStatus>(m, "TransferStatus")
.def("is_completed", &kvc::TransferStatus::isCompleted)
.def("wait", &kvc::TransferStatus::wait, nb::arg("timeout_ms") = -1);
// BaseAgentConfig struct
nb::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
.def(nb::init<>())
.def(
"__init__",
[](kvc::BaseAgentConfig* self, std::string name, bool use_prog_thread, bool multi_thread,
bool use_listen_thread, unsigned int num_workers) {
new (self) kvc::BaseAgentConfig{
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
},
nb::arg("name"), nb::arg("use_prog_thread") = true, nb::arg("multi_thread") = false,
nb::arg("use_listen_thread") = false, nb::arg("num_workers") = 1)
.def_rw("name", &kvc::BaseAgentConfig::mName)
.def_rw("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
.def_rw("multi_thread", &kvc::BaseAgentConfig::multiThread)
.def_rw("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
.def_rw("num_workers", &kvc::BaseAgentConfig::numWorkers);
// BaseTransferAgent class (abstract base)
nb::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::BaseTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership)
.def(
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
#ifdef ENABLE_NIXL
// NixlTransferStatus class - release GIL for blocking operations
nb::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
.def("wait", &kvc::NixlTransferStatus::wait, nb::arg("timeout_ms") = -1,
nb::call_guard<nb::gil_scoped_release>());
// NixlTransferAgent class
nb::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::NixlTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
.def(
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
#endif
#ifdef ENABLE_MOONCAKE
// MooncakeTransferStatus class - release GIL for blocking operations
nb::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
.def("wait", &kvc::MooncakeTransferStatus::wait, nb::arg("timeout_ms") = -1,
nb::call_guard<nb::gil_scoped_release>());
// MooncakeTransferAgent class
nb::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::MooncakeTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, nb::arg("name"),
nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, nb::arg("name"),
nb::arg("memory_descs"));
#endif
// Factory function to create transfer agent by backend name (uses dynamic loading)
m.def(
"make_transfer_agent",
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
{ return kvc::makeTransferAgent(backend, &config).release(); },
nb::arg("backend"), nb::arg("config"), nb::rv_policy::take_ownership,
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
// Expose which backends are available
#ifdef ENABLE_NIXL
m.attr("NIXL_ENABLED") = true;
#else
m.attr("NIXL_ENABLED") = false;
#endif
#ifdef ENABLE_MOONCAKE
m.attr("MOONCAKE_ENABLED") = true;
#else
m.attr("MOONCAKE_ENABLED") = false;
#endif
}
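For orientation, a minimal host-side sketch of the C++ API that the module above binds, assuming only the transferAgent header shown in this diff. The backend name, buffer address, sizes, device id, and peer name are illustrative, and passing MemoryDescs where TransferRequest expects TransferDescs is an assumption, not something the diff confirms.

// Minimal sketch (illustrative, not taken from the diff).
#include "tensorrt_llm/executor/transferAgent.h"

#include <cstddef>
#include <cstdint>
#include <optional>

namespace kvc = tensorrt_llm::executor::kv_cache;

void transferOnce(uintptr_t gpuPtr, size_t len, kvc::AgentDesc const& peerDesc)
{
    kvc::BaseAgentConfig config{/*name=*/"local", /*use_prog_thread=*/true,
        /*multi_thread=*/false, /*use_listen_thread=*/false, /*num_workers=*/1};
    auto agent = kvc::makeTransferAgent(/*backend=*/"nixl", &config);

    kvc::MemoryDescs descs{kvc::MemoryType::kVRAM, {kvc::MemoryDesc{gpuPtr, len, /*device_id=*/0}}};
    agent->registerMemory(descs);
    agent->loadRemoteAgent("peer", peerDesc);

    // Assumes TransferDescs accepts the MemoryDescs registered above.
    kvc::TransferRequest request{kvc::TransferOp::kWRITE, descs, descs, /*remote_name=*/"peer",
        /*sync_message=*/std::nullopt};
    auto status = agent->submitTransferRequests(request);
    if (status->wait(/*timeout_ms=*/-1) != kvc::TransferState::kSUCCESS)
    {
        // handle a failed transfer
    }

    agent->invalidateRemoteAgent("peer");
    agent->deregisterMemory(descs);
}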

View File

@ -0,0 +1,234 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/transferAgent.h"
#ifdef ENABLE_NIXL
#include "transferAgent.h"
#endif
#ifdef ENABLE_MOONCAKE
#include "../mooncake_utils/transferAgent.h"
#endif
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace kvc = tensorrt_llm::executor::kv_cache;
PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
{
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (pybind11)";
// MemoryType enum
py::enum_<kvc::MemoryType>(m, "MemoryType")
.value("DRAM", kvc::MemoryType::kDRAM)
.value("VRAM", kvc::MemoryType::kVRAM)
.value("BLK", kvc::MemoryType::kBLK)
.value("OBJ", kvc::MemoryType::kOBJ)
.value("FILE", kvc::MemoryType::kFILE);
// TransferOp enum
py::enum_<kvc::TransferOp>(m, "TransferOp")
.value("READ", kvc::TransferOp::kREAD)
.value("WRITE", kvc::TransferOp::kWRITE);
// TransferState enum
py::enum_<kvc::TransferState>(m, "TransferState")
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
.value("SUCCESS", kvc::TransferState::kSUCCESS)
.value("FAILURE", kvc::TransferState::kFAILURE);
// MemoryDesc class
py::class_<kvc::MemoryDesc>(m, "MemoryDesc")
.def(py::init<uintptr_t, size_t, uint32_t>(), py::arg("addr"), py::arg("len"), py::arg("device_id"))
.def_property_readonly("addr", &kvc::MemoryDesc::getAddr)
.def_property_readonly("len", &kvc::MemoryDesc::getLen)
.def_property_readonly("device_id", &kvc::MemoryDesc::getDeviceId);
// MemoryDescs class
py::class_<kvc::MemoryDescs>(m, "MemoryDescs")
.def(py::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), py::arg("type"), py::arg("descs"))
.def_property_readonly("type", &kvc::MemoryDescs::getType)
.def_property_readonly("descs", &kvc::MemoryDescs::getDescs);
// AgentDesc class
py::class_<kvc::AgentDesc>(m, "AgentDesc")
.def(py::init(
[](py::bytes data)
{
std::string str(PyBytes_AsString(data.ptr()), PyBytes_Size(data.ptr()));
return kvc::AgentDesc{std::move(str)};
}),
py::arg("backend_agent_desc"))
.def(py::init<std::string>(), py::arg("backend_agent_desc"))
.def_property_readonly("backend_agent_desc",
[](kvc::AgentDesc const& self)
{
auto const& desc = self.getBackendAgentDesc();
return py::bytes(desc.data(), desc.size());
});
// TransferRequest class
py::class_<kvc::TransferRequest>(m, "TransferRequest")
.def(py::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
std::optional<kvc::SyncMessage>>(),
py::arg("op"), py::arg("src_descs"), py::arg("dst_descs"), py::arg("remote_name"),
py::arg("sync_message") = std::nullopt)
.def_property_readonly("op", &kvc::TransferRequest::getOp)
.def_property_readonly("src_descs", &kvc::TransferRequest::getSrcDescs)
.def_property_readonly("dst_descs", &kvc::TransferRequest::getDstDescs)
.def_property_readonly("remote_name", &kvc::TransferRequest::getRemoteName)
.def_property_readonly("sync_message", &kvc::TransferRequest::getSyncMessage);
// TransferStatus base class
py::class_<kvc::TransferStatus>(m, "TransferStatus")
.def("is_completed", &kvc::TransferStatus::isCompleted)
.def("wait", &kvc::TransferStatus::wait, py::arg("timeout_ms") = -1);
// BaseAgentConfig struct
py::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
.def(py::init<>())
.def(py::init(
[](std::string name, bool use_prog_thread, bool multi_thread, bool use_listen_thread,
unsigned int num_workers) {
return kvc::BaseAgentConfig{
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
}),
py::arg("name"), py::arg("use_prog_thread") = true, py::arg("multi_thread") = false,
py::arg("use_listen_thread") = false, py::arg("num_workers") = 1)
.def_readwrite("name", &kvc::BaseAgentConfig::mName)
.def_readwrite("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
.def_readwrite("multi_thread", &kvc::BaseAgentConfig::multiThread)
.def_readwrite("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
.def_readwrite("num_workers", &kvc::BaseAgentConfig::numWorkers);
// BaseTransferAgent class (abstract base)
py::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::BaseTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership)
.def(
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
#ifdef ENABLE_NIXL
// NixlTransferStatus class - release GIL for blocking operations
py::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
.def("wait", &kvc::NixlTransferStatus::wait, py::arg("timeout_ms") = -1,
py::call_guard<py::gil_scoped_release>());
// NixlTransferAgent class
py::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::NixlTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
.def(
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
#endif
#ifdef ENABLE_MOONCAKE
// MooncakeTransferStatus class - release GIL for blocking operations
py::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
.def("wait", &kvc::MooncakeTransferStatus::wait, py::arg("timeout_ms") = -1,
py::call_guard<py::gil_scoped_release>());
// MooncakeTransferAgent class
py::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::MooncakeTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, py::arg("name"),
py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, py::arg("name"),
py::arg("memory_descs"));
#endif
// Factory function to create transfer agent by backend name (uses dynamic loading)
m.def(
"make_transfer_agent",
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
{ return kvc::makeTransferAgent(backend, &config).release(); },
py::arg("backend"), py::arg("config"), py::return_value_policy::take_ownership,
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
// Expose which backends are available
#ifdef ENABLE_NIXL
m.attr("NIXL_ENABLED") = true;
#else
m.attr("NIXL_ENABLED") = false;
#endif
#ifdef ENABLE_MOONCAKE
m.attr("MOONCAKE_ENABLED") = true;
#else
m.attr("MOONCAKE_ENABLED") = false;
#endif
}
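The pybind11 module above exposes the same surface as the nanobind one. As a rough illustration of the peer handshake implied by loadRemoteAgent and the sync-message calls, a sketch assuming two agents reachable in the same process; the peer names, the message payload, and the assumption that SyncMessage is string-like are illustrative rather than taken from the diff.

// Hypothetical two-agent handshake using only the methods bound above.
#include "tensorrt_llm/executor/transferAgent.h"

namespace kvc = tensorrt_llm::executor::kv_cache;

void handshake(kvc::BaseTransferAgent& a, kvc::BaseTransferAgent& b)
{
    // Exchange serialized agent descriptors (in practice these travel over the wire).
    a.loadRemoteAgent("b", b.getLocalAgentDesc());
    b.loadRemoteAgent("a", a.getLocalAgentDesc());

    // The connection-info overload is the alternative path:
    // a.loadRemoteAgent("b", b.getLocalConnectionInfo());

    // Lightweight notifications ride alongside data transfers; SyncMessage is assumed string-like here.
    a.notifySyncMessage("b", "kv-blocks-ready");
    auto notifications = b.getNotifiedSyncMessages();  // shape of the returned container is assumed
    (void) notifications;
}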

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <arpa/inet.h>
#include <chrono>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
@ -31,6 +32,7 @@
#include <set>
#include <sys/file.h>
#include <sys/stat.h>
#include <thread>
#include <unistd.h>
#include <vector>
@ -318,10 +320,40 @@ NixlTransferStatus::NixlTransferStatus(nixlAgent* agent, nixlXferReqH* handle)
TLLM_CHECK(mHandle);
}
void NixlTransferStatus::wait() const
TransferState NixlTransferStatus::wait(int64_t timeout_ms) const
{
while (!isCompleted())
;
auto startTime = std::chrono::steady_clock::now();
while (true)
{
auto status = mRawAgent->getXferStatus(mHandle);
if (status == NIXL_SUCCESS)
{
return TransferState::kSUCCESS;
}
else if (status != NIXL_IN_PROG)
{
return TransferState::kFAILURE;
}
// If timeout_ms < 0, wait indefinitely until status is not NIXL_IN_PROG
if (timeout_ms < 0)
{
std::this_thread::yield();
continue;
}
// Check if timeout has elapsed
auto elapsed
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
.count();
if (elapsed >= timeout_ms)
{
return TransferState::kIN_PROGRESS;
}
std::this_thread::yield();
}
}
[[nodiscard]] bool NixlTransferStatus::isCompleted() const
@ -333,6 +365,7 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
: mName{config.mName}
{
nixl_status_t status;
if (config.useListenThread)
{
FileLock lock("/tmp/trtllm_nixl_port.lock");
if (!lock.lock())
@ -341,10 +374,18 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
}
auto envPort = common::getEnvNixlPort();
uint16_t port = envPort > 0 ? getIncrmentPort(envPort) : getAvailablePort();
nixlAgentConfig nixlConfig{config.useProgThread, true, port};
nixlAgentConfig nixlConfig{
config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
mAddress = getAvailableIP() + ":" + std::to_string(port);
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}
else
{
mAddress.clear();
nixlAgentConfig nixlConfig{
config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}
std::string nixlBackend = common::getEnvNixlBackend();
// List of supported backends - extend this list as new backends are added
@ -645,7 +686,8 @@ void NixlLoopbackAgent::executeLoopbackRequest(
std::unique_ptr<TransferStatus> status = this->submitLoopbackRequests(memoryDescs, fileDescs, isOffload);
TLLM_CHECK_WITH_INFO(status != nullptr, "submitLoopbackRequests failed");
status->wait();
TransferState transferState = status->wait();
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "submitLoopbackRequests failed");
this->deregisterMemory(memoryDescs);
this->deregisterFiles(fileDescs);
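With the new signature, wait() no longer spins forever: kSUCCESS and kFAILURE are terminal, and kIN_PROGRESS now means the timeout elapsed while the transfer was still in flight. A small caller-side sketch under that contract; the 100 ms timeout and the retry budget are illustrative.

// Sketch: bounded waiting on a submitted transfer, retrying on timeout (illustrative values).
#include "tensorrt_llm/executor/transferAgent.h"

namespace kvc = tensorrt_llm::executor::kv_cache;

kvc::TransferState waitWithBudget(kvc::TransferStatus const& status, int maxAttempts)
{
    for (int attempt = 0; attempt < maxAttempts; ++attempt)
    {
        auto state = status.wait(/*timeout_ms=*/100);  // kIN_PROGRESS here means "timed out"
        if (state != kvc::TransferState::kIN_PROGRESS)
        {
            return state;  // kSUCCESS or kFAILURE
        }
    }
    return kvc::TransferState::kIN_PROGRESS;  // still in flight after the budget
}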

View File

@ -45,7 +45,7 @@ public:
[[nodiscard]] bool isCompleted() const override;
void wait() const override;
[[nodiscard]] TransferState wait(int64_t timeout_ms = -1) const override;
private:
nixlAgent* mRawAgent{};

View File

@ -2179,11 +2179,11 @@ void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissio
auto req = item.request;
if (req->isDisaggContextCompleteState())
{
// If lastBlockId was tracked, unpin it. Otherwise, just terminate.
// If pinnedBlockIds were tracked, unpin them. Otherwise, just terminate.
auto kvMgr = mModel->getKVCacheManager();
if (kvMgr && item.lastBlockId.has_value())
if (kvMgr && !item.pinnedBlockIds.empty())
{
kvMgr->unpinBlocksById(item.lastBlockId.value());
kvMgr->unpinBlocksById(item.pinnedBlockIds);
}
else
{
@ -2234,14 +2234,14 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses(
// move the in transmission requests to another tracker
if (llmReq->isDisaggContextTransmissionState())
{
std::optional<SizeType32> lastBlockId{};
std::vector<SizeType32> pinnedBlockIds{};
auto kvMgr = mModel->getKVCacheManager();
if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
{
lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
mModel->terminateRequest(llmReq);
}
inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId});
inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
}
finishedRequests.push_back(*it);
it = activeRequests.erase(it);

View File

@ -80,12 +80,12 @@ class Executor::Impl
using RequestList = std::list<LlmRequestPtr>;
// When block reuse is enabled for context worker for disaggregated serving,
// we need to store the last block id so that we can unpin the block when
// we need to store the pinned block ids so that we can unpin them when
// the request is finished.
struct InTransmissionItem
{
LlmRequestPtr request;
std::optional<SizeType32> lastBlockId;
std::vector<SizeType32> pinnedBlockIds;
};
using InTransList = std::list<InTransmissionItem>;
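The change from a single lastBlockId to pinnedBlockIds means every block stored for reuse stays pinned until transmission completes, and all of them are released together. A condensed sketch of that lifecycle using only the calls visible in this diff; the template parameters and the SizeType32 alias stand in for the executor's real types.

// Sketch: pin blocks when a context request enters transmission, unpin them all afterwards.
#include <cstdint>
#include <vector>

using SizeType32 = std::int32_t;  // stand-in for the executor's SizeType32

template <typename KvCacheManager, typename LlmRequestPtr>
std::vector<SizeType32> pinForTransmission(KvCacheManager& kvMgr, LlmRequestPtr const& llmReq)
{
    std::vector<SizeType32> pinnedBlockIds{};
    if (kvMgr.isEnableBlockReuse() && !kvMgr.getBlockManager().isVariableWindow())
    {
        pinnedBlockIds = kvMgr.storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
    }
    return pinnedBlockIds;
}

template <typename KvCacheManager>
void unpinAfterTransmission(KvCacheManager& kvMgr, std::vector<SizeType32> const& pinnedBlockIds)
{
    if (!pinnedBlockIds.empty())
    {
        kvMgr.unpinBlocksById(pinnedBlockIds);  // every pinned block, not just the last one
    }
}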

View File

@ -70,9 +70,9 @@ struct LamportComm
{
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
flag_value = *flag_ptr;
int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
clear_size = *clear_ptr;
int data_offset = flag_value % 3;
int clear_offset = (flag_value + 2) % 3;
@ -88,7 +88,7 @@ struct LamportComm
}
}
__device__ __forceinline__ void update(int new_clear_size)
__device__ __forceinline__ void update(int64_t new_clear_size)
{
if (blockIdx.x == 0 && threadIdx.x == 0)
{
@ -103,10 +103,10 @@ struct LamportComm
int* counter_ptr;
int* flag_ptr;
int* clear_ptr;
int64_t* clear_ptr;
uint8_t* data_bufs[NRanks];
uint8_t* clear_buf;
int clear_size;
int64_t clear_size;
int flag_value;
};

View File

@ -21,18 +21,18 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels::ar_fusion
{
__global__ void lamport_initialize_kernel(float* ptr, int size)
__global__ void lamport_initialize_kernel(float* ptr, size_t size)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (idx >= size)
return;
ptr[idx] = -0.f;
}
void lamport_initialize(void* ptr, int bytes, cudaStream_t stream)
void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream)
{
int grid_size = (bytes + 127) / 128;
lamport_initialize_kernel<<<grid_size, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
int grid_size = static_cast<int>((bytes + 1023) / 1024);
lamport_initialize_kernel<<<grid_size, 1024, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
}
Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
@ -45,10 +45,11 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
int device_id;
TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
m_buffer_mgr = std::make_shared<tensorrt_llm::runtime::BufferManager>(m_cuda_stream);
int buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
int flag_size = tp_size * kBarrierFlagCount * sizeof(int);
int lamport_comm_size = tp_size * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
int lamport_buffer_size = 3 * lamport_comm_size;
size_t buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
size_t flag_size = tp_size * kBarrierFlagCount * sizeof(int);
size_t lamport_comm_size
= static_cast<size_t>(tp_size) * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
size_t lamport_buffer_size = 3 * lamport_comm_size;
for (auto size : {buffer_size, flag_size, lamport_buffer_size})
{
m_ipc_mem_handles.emplace_back(size, *m_buffer_mgr, m_world_config, p2p_supported);
@ -61,20 +62,20 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
workspace.push_back(ipc_mem_handle.getCommPtrs()[r]);
}
}
// atomic flag read counter
// kernel_flag_ptr[0] = 0;
// non-lamport flag
// kernel_flag_ptr[1] = 0;
// lamport flag
// kernel_flag_ptr[2] = 0;
// lamport triple buffer offset
// kernel_flag_ptr[3] = lamport_comm_size;
// lamport clear size
// kernel_flag_ptr[4] = 0;
TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 5 * sizeof(int)));
std::vector<int> h_data{0, 0, 0, lamport_comm_size, 0};
TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_data.data(), 5 * sizeof(int), cudaMemcpyHostToDevice));
// flag_buffer[0], atomic flag read counter
// flag_buffer[1], non-lamport flag
// flag_buffer[2], lamport flag
TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 3 * sizeof(int)));
std::vector<int> h_flag_data{0, 0, 0};
TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_flag_data.data(), 3 * sizeof(int), cudaMemcpyHostToDevice));
workspace.push_back(m_flag_d_ptr);
// layout_buffer[0], clear size for next lamport kernel
// layout_buffer[1], triple buffer offset for lamport kernel
TLLM_CUDA_CHECK(cudaMalloc(&m_layout_d_ptr, 2 * sizeof(int64_t)));
std::vector<int64_t> h_layout_data{0, static_cast<int64_t>(lamport_comm_size)};
TLLM_CUDA_CHECK(cudaMemcpy(m_layout_d_ptr, h_layout_data.data(), 2 * sizeof(int64_t), cudaMemcpyHostToDevice));
workspace.push_back(m_layout_d_ptr);
TLLM_CUDA_CHECK(cudaMalloc(&m_workspace, workspace.size() * sizeof(void*)));
TLLM_CUDA_CHECK(
cudaMemcpy(m_workspace, workspace.data(), workspace.size() * sizeof(void*), cudaMemcpyHostToDevice));
@ -87,6 +88,10 @@ Workspace::~Workspace()
{
TLLM_CUDA_CHECK(cudaFree(m_flag_d_ptr));
}
if (m_layout_d_ptr)
{
TLLM_CUDA_CHECK(cudaFree(m_layout_d_ptr));
}
if (m_workspace)
{
TLLM_CUDA_CHECK(cudaFree(m_workspace));
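The int to size_t widening above is not cosmetic: the Lamport buffer is a triple-buffered product of four factors, and realistic shapes already overflow 32-bit arithmetic. A small worked example with assumed (illustrative) sizes:

// Sketch: why lamport_buffer_size needs a 64-bit type (values are illustrative).
#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t const tp_size = 8;
    std::size_t const max_token_num = 8192;   // assumed larger than kOneShotMaxToken
    std::size_t const hidden_dim = 16384;
    std::size_t const elt_size = 2;           // sizeof(half)

    std::size_t const lamport_comm_size = tp_size * max_token_num * hidden_dim * elt_size;
    std::size_t const lamport_buffer_size = 3 * lamport_comm_size;

    // lamport_comm_size   = 8 * 8192 * 16384 * 2 = 2,147,483,648 bytes, already above INT32_MAX (2,147,483,647);
    // lamport_buffer_size = 3 * lamport_comm_size = 6,442,450,944 bytes (~6 GiB),
    // so the old int arithmetic would have wrapped before reaching the allocator.
    std::printf("%zu\n", lamport_buffer_size);
    return 0;
}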

View File

@ -41,9 +41,10 @@ private:
void* m_workspace;
std::shared_ptr<tensorrt_llm::runtime::CudaStream> m_cuda_stream;
void* m_flag_d_ptr;
void* m_layout_d_ptr;
};
void lamport_initialize(void* ptr, int bytes, cudaStream_t stream);
void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream);
} // namespace kernels::ar_fusion
TRTLLM_NAMESPACE_END

View File

@ -230,59 +230,62 @@ inline __device__ __host__ T divUp(T m, T n)
// Return (block_size, cluster_size, loads_per_thread)
std::tuple<int, int, int> adjustGridConfig(int numTokens, int dim, int eltsPerThread)
{
// Start with preferred block_size and cluster_size
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
int clusterSize = 8;
#else
int clusterSize = 1;
#endif
static int SM = tensorrt_llm::common::getSMVersion();
int clusterSize = SM >= 90 ? 8 : 1;
int blockSize = 128;
// ========================== Adjust the grid configuration ==========================
int threadsNeeded = divUp(dim, eltsPerThread);
int loadsPerThread = 1;
blockSize = divUp(threadsNeeded, clusterSize);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
if (clusterSize > 1)
{
clusterSize /= 2;
while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
{
clusterSize /= 2;
}
blockSize = divUp(threadsNeeded, clusterSize);
while (blockSize < 128 && clusterSize >= 2)
{
blockSize *= 2;
clusterSize /= 2;
}
int smCount = getMultiProcessorCount();
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
{
blockSize *= 2;
clusterSize /= 2;
}
}
blockSize = divUp(threadsNeeded, clusterSize);
while (blockSize < 128 && clusterSize >= 2)
{
blockSize *= 2;
clusterSize /= 2;
}
int smCount = getMultiProcessorCount();
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
{
blockSize *= 2;
clusterSize /= 2;
}
#endif
// Try to scale up using multiple loads or CGA

while (blockSize > 1024)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
if (clusterSize < 8)
// Scale up with CGA if supported
if (SM >= 90)
{
clusterSize = clusterSize << 1;
if (clusterSize < 8)
{
clusterSize = clusterSize << 1;
}
else
{
break;
}
}
else
{
break;
if (loadsPerThread < 8)
{
loadsPerThread += 1;
}
else
{
break;
}
}
#else
if (loadsPerThread < 8)
{
loadsPerThread += 1;
}
else
{
break;
}
#endif
blockSize = divUp(threadsNeeded, clusterSize * loadsPerThread);
}
return {blockSize, clusterSize, loadsPerThread};
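Read as a whole, adjustGridConfig now chooses the cluster size from the runtime SM version instead of from __CUDA_ARCH__, which is not defined when this host function is compiled. A small standalone trace of the arithmetic above for assumed (illustrative) inputs on an SM90 part:

// Sketch (illustrative inputs): the arithmetic adjustGridConfig performs for SM >= 90.
#include <cstdio>

int main()
{
    int const numTokens = 8, dim = 8192, eltsPerThread = 8, smCount = 132;      // smCount assumed
    int const threadsNeeded = (dim + eltsPerThread - 1) / eltsPerThread;        // 1024
    int clusterSize = 8;                                                        // SM >= 90
    while (threadsNeeded % clusterSize != 0 && clusterSize > 1) clusterSize /= 2;  // stays 8
    int blockSize = (threadsNeeded + clusterSize - 1) / clusterSize;            // 128
    while (blockSize < 128 && clusterSize >= 2) { blockSize *= 2; clusterSize /= 2; }          // no change
    while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
    { blockSize *= 2; clusterSize /= 2; }                                       // 8 * 8 = 64 <= 132, no change
    std::printf("block=%d cluster=%d loads=1\n", blockSize, clusterSize);       // block=128 cluster=8
    return 0;
}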
@ -420,9 +423,9 @@ __global__ void __launch_bounds__(1024) oneshotAllreduceFusionKernel(T* outputPt
}
float blockSum = blockReduceSum<float, true>(threadSum);
__shared__ float sharedVal[8]; // Temporary variable to share the sum within block
float fullSum = blockSum;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__shared__ float sharedVal[8]; // Temporary variable to share the sum within block
namespace cg = cooperative_groups;
cg::cluster_group cluster = cg::this_cluster();
int const numBlocks = cluster.num_blocks();
@ -459,6 +462,8 @@ using detail::adjustGridConfig;
void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
{
static int const kSMVersion = tensorrt_llm::common::getSMVersion();
int const numTokens = params.numTokens;
int const tokenDim = params.tokenDim;
int const eltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
@ -466,38 +471,31 @@ void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
auto [blockSize, clusterSize, loadsPerThread] = adjustGridConfig(numTokens, tokenDim, eltsPerThread);
dim3 grid(numTokens, clusterSize, 1);
TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
"Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
1024 * 8 * eltsPerThread);
#else
1024 * eltsPerThread);
#endif
TLLM_LOG_DEBUG(
"[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: %d, "
"loads_per_thread: %d, "
"threads_needed: %d",
numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, divUp(tokenDim, eltsPerThread));
TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
"Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
1024 * (kSMVersion >= 90 ? 8 : 1) * eltsPerThread);
cudaLaunchAttribute attrs[2];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
attrs[1].id = cudaLaunchAttributeClusterDimension;
attrs[1].val.clusterDim.x = 1;
attrs[1].val.clusterDim.y = clusterSize;
attrs[1].val.clusterDim.z = 1;
#endif
cudaLaunchConfig_t config
{
.gridDim = grid, .blockDim = blockSize, .dynamicSmemBytes = 0, .stream = params.stream, .attrs = attrs,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
.numAttrs = 2,
#else
.numAttrs = 1,
#endif
cudaLaunchConfig_t config{
.gridDim = grid,
.blockDim = blockSize,
.dynamicSmemBytes = 0,
.stream = params.stream,
.attrs = attrs,
.numAttrs = kSMVersion >= 90 ? 2U : 1U,
};
#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T, RMSNORM) \
@ -831,9 +829,9 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
float blockSum = blockReduceSum<float, true>(threadSum);
float fullSum = blockSum;
__shared__ float sharedVal[8];
// Use CGA Reduction if supported
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__shared__ float sharedVal[8];
int const numBlocks = cluster.num_blocks();
if (numBlocks > 1)
{
@ -876,6 +874,11 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
}
constexpr int kELTS_SIZE = sizeof(T_IN);
// Issue ACQBLK at the end. Assuming preceding kernel will not modify the buffer_flags.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
// Update the buffer pointers
flag.waitAndUpdate({static_cast<uint32_t>(divUp<uint32_t>(numTokens, worldSize) * worldSize * dim * kELTS_SIZE),
static_cast<uint32_t>(numTokens * dim * kELTS_SIZE), 0, 0});
@ -883,6 +886,7 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
{
static int const kSMVersion = tensorrt_llm::common::getSMVersion();
int const numTokens = params.numTokens;
int const tokenDim = params.tokenDim;
int const numEltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
@ -959,17 +963,13 @@ void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
rnConfig.attrs = rnAttrs;
rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
rnAttrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
#ifndef DISABLE_CGA
rnAttrs[1].id = cudaLaunchAttributeClusterDimension;
rnAttrs[1].val.clusterDim.x = 1;
rnAttrs[1].val.clusterDim.y = rnClusterSize;
rnAttrs[1].val.clusterDim.z = 1;
rnConfig.numAttrs = 2;
#else
rnConfig.numAttrs = 1;
#endif
rnConfig.numAttrs = (kSMVersion >= 90) ? 2U : 1U;
bool const rnUseCGA = rnClusterSize > 1;
bool const rnUseCGA = kSMVersion >= 90 && rnClusterSize > 1;
int const dimPadded = divUp(tokenDim, numEltsPerThread * rnNumThreads) * numEltsPerThread * rnNumThreads;
int const iters = dimPadded / rnNumThreads;

View File

@ -31,9 +31,9 @@ struct LamportComm
{
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
flag_value = *flag_ptr;
int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
clear_size = *clear_ptr;
int data_offset = flag_value % 3;
int clear_offset = (flag_value + 2) % 3;
@ -49,7 +49,7 @@ struct LamportComm
}
}
__device__ __forceinline__ void update(int new_clear_size)
__device__ __forceinline__ void update(int64_t new_clear_size)
{
if (blockIdx.x == 0 && threadIdx.x == 0)
{
@ -64,10 +64,10 @@ struct LamportComm
int* counter_ptr;
int* flag_ptr;
int* clear_ptr;
int64_t* clear_ptr;
uint8_t* data_bufs[NRanks];
uint8_t* clear_buf;
int clear_size;
int64_t clear_size;
int flag_value;
};

View File

@ -48,6 +48,12 @@ namespace kernels::moe_comm
#define SWITCH_TOP_K(top_k, TOP_K, ...) \
switch (top_k) \
{ \
case 22: \
{ \
constexpr int TOP_K = 22; \
__VA_ARGS__; \
break; \
} \
case 16: \
{ \
constexpr int TOP_K = 16; \
@ -362,88 +368,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int thread_idx = ThreadingPolicy::offset();
int local_token_idx = ThreadingPolicy::token_idx();
if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
}
// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
if (already_copied & (1ULL << target_rank))
// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}
uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
if (already_copied & (1ULL << target_rank))
{
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
}
continue;
}
// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
continue;
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();
// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
#pragma unroll
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
}
// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank,
payload_idx, ptrs, topk_target_ranks, topk_send_indices);
}
ThreadingPolicy::sync();
}
// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx,
ptrs, topk_target_ranks, topk_send_indices);
}
ThreadingPolicy::sync();
bool is_first_warp = threadIdx.x / warpSize == 0;
if (is_first_warp)
{
@ -452,8 +468,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
bool is_last_token = false;
if (lane_id == 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
if (local_num_tokens != 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
}
else
{
is_last_token = true;
}
}
is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
@ -523,7 +546,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);
// Prepare kernel pointers struct
@ -568,6 +591,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
if (params.one_block_per_token)
{
int grid_size = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@ -577,6 +605,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
else
{
int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@ -626,7 +659,70 @@ __device__ void vectorized_combine_impl(
// Load directly into the per-k accumulator; reduce across k below
acc[k].load(recv_buffer + base_token + offset);
}
if constexpr (TOP_K == 16)
// Reduce acc[TOP_K] into acc[0]
if constexpr (TOP_K == 22)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
T* a2 = reinterpret_cast<T*>(&acc[2]);
T* a3 = reinterpret_cast<T*>(&acc[3]);
T* a4 = reinterpret_cast<T*>(&acc[4]);
T* a5 = reinterpret_cast<T*>(&acc[5]);
T* a6 = reinterpret_cast<T*>(&acc[6]);
T* a7 = reinterpret_cast<T*>(&acc[7]);
T* a8 = reinterpret_cast<T*>(&acc[8]);
T* a9 = reinterpret_cast<T*>(&acc[9]);
T* a10 = reinterpret_cast<T*>(&acc[10]);
T* a11 = reinterpret_cast<T*>(&acc[11]);
T* a12 = reinterpret_cast<T*>(&acc[12]);
T* a13 = reinterpret_cast<T*>(&acc[13]);
T* a14 = reinterpret_cast<T*>(&acc[14]);
T* a15 = reinterpret_cast<T*>(&acc[15]);
T* a16 = reinterpret_cast<T*>(&acc[16]);
T* a17 = reinterpret_cast<T*>(&acc[17]);
T* a18 = reinterpret_cast<T*>(&acc[18]);
T* a19 = reinterpret_cast<T*>(&acc[19]);
T* a20 = reinterpret_cast<T*>(&acc[20]);
T* a21 = reinterpret_cast<T*>(&acc[21]);
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a1[j];
a2[j] += a3[j];
a4[j] += a5[j];
a6[j] += a7[j];
a8[j] += a9[j];
a10[j] += a11[j];
a12[j] += a13[j];
a14[j] += a15[j];
a16[j] += a17[j];
a18[j] += a19[j];
a20[j] += a21[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a2[j];
a4[j] += a6[j];
a8[j] += a10[j];
a12[j] += a14[j];
a16[j] += a18[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a4[j];
a8[j] += a12[j];
a16[j] += a20[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a8[j];
a0[j] += a16[j];
}
}
else if constexpr (TOP_K == 16)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
@ -710,9 +806,7 @@ __device__ void vectorized_combine_impl(
a0[j] += a8[j];
}
}
// Reduce acc[TOP_K] into acc[0]
if constexpr (TOP_K == 8)
else if constexpr (TOP_K == 8)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
@ -897,9 +991,19 @@ __global__ void moeA2ACombineKernel(
int local_token_idx = ThreadingPolicy::token_idx();
int const size_per_token = elements_per_token * sizeof(T);
if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
}
#if !DISABLE_SYNC_FOR_PROFILING
@ -951,6 +1055,9 @@ __global__ void moeA2ACombineKernel(
__syncthreads();
#endif
if (local_num_tokens == 0)
return;
// Get output location for this token (using src_data_ptrs[0] as output)
T* token_output = static_cast<T*>(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token;
@ -1003,7 +1110,7 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.elements_per_token > 0);
// Configure kernel launch
@ -1011,6 +1118,15 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
int const kWarpsPerBlock = kBlockSize / 32; // warpSize
int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
int grid_size_block = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size_warp == 0)
{
grid_size_warp = 1;
}
if (grid_size_block == 0)
{
grid_size_block = 1;
}
// Prepare kernel pointers struct for combine
CombineKernelPointers kernel_ptrs = {}; // Zero-initialize
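Taken together, the dispatch and combine changes above let a rank that received no tokens still take part in the cross-rank synchronization: the host launches at least one block, and inside the kernel only the local_token_idx == 0 threads stay alive for the flag exchange. A condensed sketch of the host-side guard; the function name and parameters are illustrative:

// Sketch: grid sizing that keeps zero-token ranks in the synchronization.
int gridSizeFor(int localNumTokens, int warpsPerBlock, bool oneBlockPerToken)
{
    int grid = oneBlockPerToken ? localNumTokens
                                : (localNumTokens + warpsPerBlock - 1) / warpsPerBlock;
    // A rank with no local tokens still launches one block so its first warp can take
    // part in the cross-rank completion counter / flag exchange.
    return grid == 0 ? 1 : grid;
}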

View File

@ -26,7 +26,7 @@ namespace kernels::moe_comm
{
// Configuration constants
static constexpr int kMaxTopK = 16; // Maximum top-k experts per token
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
static constexpr int kMaxRanks = 64; // Maximum supported EP size

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6509dd36fb92554c6078595951a8de698d7bdaa07b9b817bfcdd255d4303bca
size 687070
oid sha256:4f1f3679968b8f6dea77f53534af9eb1348b6f476d4c3880833b41dd4cc9c803
size 687860

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b22d606e19b52047ae67319d61f138562f2b81df08ccde3f8fa04f040d408d7a
size 669688
oid sha256:a0d7061b400ab387309af00ae12f7a840b5abb91757183f415ca18329bbdb358
size 670478

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a70e335677a1b0f9d98267fe7701735e42f105720403489276d48a4247ea1b5
size 423835
oid sha256:4a91ff0238b0c8f1d40f8441f22a60a2c64d344b8550de68737292ff449d1d7e
size 426203

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8289200bf78517033295966e9dbdf5c647da9aa7089669ff473ba436fef6a798
size 1230152
oid sha256:4d094c39dbdd372166facb297a4a91be80fb231bf3cca89afa97e61cc725f67e
size 1228572

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97cc5f8d42d92332a92fa216847bbacccc7ef9f9d5208bd26585cd702d03fe57
size 1725040
oid sha256:1fe830d32459fd9a25d54e1d00a98720afd938d9e9042e2b5903f969e991d72d
size 1721882

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1264927817c08da144e387a7258f6c6fe424c0ff159f3ab0d6ffa3c4e3947598
size 375671
oid sha256:09af1ef9197c628c4a31cc58276ee6dcfad03f751069a78b5242594f93ea8c97
size 378039

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:950fb45e94ffc8e2ec9f5a4b682075be55cb85d6415b3eeb172ce2cf7d53220d
size 1140954
oid sha256:9e93bb514c30bc5a4cda8f402a386ab85d079f9b97aeff04788cf3c8a8cc87a6
size 1137008

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba97e1bf342788eaf74a78f542f870d3967214aed98b98600fae772aad5bad5f
size 653960
oid sha256:0dc47824dfc41004c5b243ce9f40eefeee15c69b88474e33ec13137ef56604e8
size 651592

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:337cc83d1880b1496e2f054285472b693c181e081819f425ddf2ea45a5dfe9f4
size 1130682
oid sha256:c0f042eabb29ee9db7ddf9791840337a7544653b295e4b2a5068b7f80bcd8251
size 1128314

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:859ffffa18f1c9c8068a1cfedec487c2e0eab84af2c3720eaa7bb2a044ea16f6
size 1534006
oid sha256:7a9d887dd0acea6d82a25e0dda908f4c5421eaa1ddbfeeb49d382c079156d67e
size 1535586

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02bc55faacb50d0501c590ed11b40d802b374618cbde58db725cc67495762064
size 698136
oid sha256:22a7eaab8e44194acd83621e5546f164ad9cbeda8b67867f864a235036a03931
size 690242

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:510d6c9942dea4bef976c2307fc63f1d7341d78ad8b41cca3bf80bae0a377575
size 380847
oid sha256:e22fe2dde7f5542975db7517b37cdce0eaa656fed2bc58378b37a872c54a43ef
size 374533

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0e0d34e15f533f756ac4ad6ef8889e5ed7556d859b6263509f608f2e7194e0a
size 964134

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd7941b92a10c3116b3d93b50ce94d90627ed020e1aa4263b2c46926db60250
size 1008328

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04439f4bdd5bf15dce0d59e455545236ed5b98c963a9b491c40d473eb766a04f
size 988580
oid sha256:ec624d7dceea5234b9dd4e43125f271e46ed4f2a4118837a23e00eb89571dcb2
size 985422

Some files were not shown because too many files have changed in this diff.