Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Compare commits: v1.2.0rc6 ... main (382 commits)
.github/CODEOWNERS (19 changes)

@@ -1,5 +1,18 @@
# This file defines code ownership rules for the repository.

## TensorRT-LLM QA
### Integration Tests
/tests/integration/test_lists/qa @NVIDIA/trt-llm-qa
/tests/integration/defs/examples/test_ray.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/examples/test_redrafter.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/accuracy @NVIDIA/trt-llm-qa-function
/tests/integration/defs/stress_test @NVIDIA/trt-llm-qa-function
/tests/integration/defs/triton_server @NVIDIA/trt-llm-qa-function
/tests/integration/defs/test_e2e.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/disaggregated @NVIDIA/trt-llm-qa-serving
/tests/integration/defs/sysinfo @NVIDIA/trt-llm-qa-perf
/tests/integration/defs/perf @NVIDIA/trt-llm-qa-perf
/tests/integration/defs/perf/disagg @NVIDIA/trt-llm-qa-serving

## TensorRT-LLM Infra
### CI

@@ -13,6 +26,11 @@

## TensorRT-LLM - Docs
/docs @NVIDIA/trt-llm-doc-owners
/CODING_GUIDELINES.md @NVIDIA/trt-llm-doc-owners
/CODE_OF_CONDUCT.md @NVIDIA/trt-llm-doc-owners
/CONTAINER_SOURCE.md @NVIDIA/trt-llm-doc-owners
/CONTRIBUTING.md @NVIDIA/trt-llm-doc-owners
/README.md @NVIDIA/trt-llm-doc-owners

## Examples
/examples @NVIDIA/trt-llm-doc-owners

@@ -183,6 +201,7 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
## and license compliance when adding, removing, or changing versions of dependencies.
### License Files
/LICENSE @NVIDIA/trt-llm-oss-compliance
/ATTRIBUTIONS-*.md @NVIDIA/trt-llm-oss-compliance
/jenkins/license_cpp.json @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs @NVIDIA/trt-llm-oss-compliance

### Python Dependency Management
.github/workflows/auto-assign.yml (4 changes)

@@ -11,10 +11,10 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
-       uses: actions/checkout@v2
+       uses: actions/checkout@v6

      - name: Get assignee
-       uses: actions/github-script@v6
+       uses: actions/github-script@v8
        id: get-assignee
        with:
          github-token: ${{secrets.GITHUB_TOKEN}}

@@ -14,7 +14,7 @@ jobs:
      pull-requests: write

    steps:
-     - uses: actions/stale@v9
+     - uses: actions/stale@v10
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          stale-issue-message: 'Issue has not received an update in over 14 days. Adding stale label.'
.github/workflows/blossom-ci.yml (7 changes)

@@ -53,6 +53,7 @@ jobs:
          "amukkara",
          "anish-shanbhag",
          "arekay",
          "arysef",
          "atrifex",
          "Autumn1998",
          "baize97",

@@ -121,6 +122,7 @@ jobs:
          "heyuhhh",
          "hijkzzz",
          "hlu1",
          "hnover-nv",
          "HuiGao-NV",
          "hvagadia",
          "hypdeb",

@@ -154,6 +156,7 @@ jobs:
          "kaiyux",
          "kanghui0204",
          "karljang",
          "karthikvetrivel",
          "katec846",
          "Kefeng-Duan",
          "KingsleyLiu-NV",

@@ -191,6 +194,7 @@ jobs:
          "mlefeb01",
          "moraxu",
          "MrGeva",
          "mzweilz",
          "Naveassaf",
          "nekorobov",
          "netanel-haber",

@@ -215,6 +219,7 @@ jobs:
          "omera-nv",
          "pamelap-nvidia",
          "pcastonguay",
          "pcicotti",
          "pdrake-nv",
          "peaceh-nv",
          "pengbowang-nv",

@@ -243,6 +248,7 @@ jobs:
          "schetlur-nv",
          "shaharmor98",
          "shangz-ai",
          "sherry-1001",
          "shifangx",
          "Shixiaowei02",
          "Shunkangz",

@@ -262,6 +268,7 @@ jobs:
          "syuoni",
          "Tabrizian",
          "talorabr",
          "taylor-yb-lee",
          "tburt-nv",
          "tcherckez-nvidia",
          "thorjohnsen",
.github/workflows/bot-command.yml (2 changes)

@@ -36,7 +36,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Add bot help comment
-       uses: actions/github-script@v6
+       uses: actions/github-script@v8
        with:
          script: |
            const helpMessage = "" +
.github/workflows/l0-test.yml (4 changes)

@@ -34,7 +34,7 @@ jobs:
    if: github.event_name == 'workflow_dispatch'
    steps:
      - name: Update commit status
-       uses: actions/github-script@v6
+       uses: actions/github-script@v8
        with:
          script: |
            state = 'pending'

@@ -60,7 +60,7 @@ jobs:
        with:
          paths: results/**/results*.xml
      - name: Update commit status
-       uses: actions/github-script@v6
+       uses: actions/github-script@v8
        with:
          script: |
            github.rest.repos.createCommitStatus({
.github/workflows/label_community_pr.yml (4 changes)

@@ -17,10 +17,10 @@ jobs:
    if: github.repository == 'NVIDIA/TensorRT-LLM'
    steps:
      - name: Checkout repository
-       uses: actions/checkout@v3
+       uses: actions/checkout@v6

      - name: Set up Python
-       uses: actions/setup-python@v3
+       uses: actions/setup-python@v6
        with:
          python-version: '3.x'
.github/workflows/label_issue.yml (2 changes)

@@ -13,7 +13,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout private action repository
-       uses: actions/checkout@v4
+       uses: actions/checkout@v6
        with:
          repository: NVIDIA/goggles_action
          path: ./.github/actions/goggles_action # local path to store the action
.github/workflows/pr-check.yml (4 changes)

@@ -59,10 +59,10 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
-       uses: actions/checkout@v4
+       uses: actions/checkout@v6

      - name: Set up Python
-       uses: actions/setup-python@v5
+       uses: actions/setup-python@v6
        with:
          python-version: '3.10'
.github/workflows/precommit-check.yml (4 changes)

@@ -29,11 +29,11 @@ jobs:
    name: Pre-commit Check
    runs-on: ubuntu-latest
    steps:
-     - uses: actions/checkout@v4
+     - uses: actions/checkout@v6
        with:
          ref: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.ref || github.ref }}

-     - uses: actions/setup-python@v5
+     - uses: actions/setup-python@v6
        with:
          python-version: '3.12'
          cache: 'pip'
.gitignore (6 changes)

@@ -40,6 +40,8 @@ tensorrt_llm/libs
tensorrt_llm/bindings.*.so
tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/**/*.pyi
tensorrt_llm/tensorrt_llm_transfer_agent_binding.*.so
tensorrt_llm/tensorrt_llm_transfer_agent_binding.pyi
tensorrt_llm/deep_ep/
tensorrt_llm/deep_ep_cpp_tllm.*.so
tensorrt_llm/deep_ep_cpp_tllm.pyi

@@ -56,13 +58,14 @@ tensorrt_llm/scripts
docs/source/**/*.rst
!docs/source/examples/index.rst
!docs/source/deployment-guide/config_table.rst
-!docs/source/deployment-guide/note_sections.rst
+!docs/source/_includes/note_sections.rst
*.swp

# Testing
.coverage.*
results_trt/
llm-test-workspace/
ad-test-workspace/

# build/debug
*.safetensors

@@ -76,6 +79,7 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
.devcontainer/.env
/examples/layer_wise_benchmarks/autotuner_cache/
/examples/layer_wise_benchmarks/profiles/

# User config files
3rdparty/CMakeLists.txt (4 changes)

@@ -38,8 +38,8 @@ FetchContent_Declare(

FetchContent_Declare(
  deepgemm
- GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM
- GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch
+ GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
+ GIT_TAG 4ff3f54d9b7ed3129e4f36f9871232ea7ecab86b # nv_dev branch
  GIT_SUBMODULES_RECURSE
  ON
  SOURCE_SUBDIR
@@ -487,9 +487,17 @@ else:
    f.read()
```

## Documentation Guidelines

#### CLI Options in Documentation
1. When documenting CLI commands for `trtllm-serve`, `trtllm-bench`, `trtllm-eval`, or similar tools, prefer using `--config` over `--extra_llm_api_options` for specifying configuration files.
   - `--config` is the preferred, shorter alias for configuration file options.
   - Example: `trtllm-serve --model <model_path> --config config.yaml` (preferred)
   - Avoid: `trtllm-serve --model <model_path> --extra_llm_api_options config.yaml`

## NVIDIA Copyright

-1. All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. The following block of text should be prepended to the top of all files. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
+1. All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification. The following block of text should be prepended to the top of all files. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
```cpp
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
@@ -10,7 +10,7 @@ state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.
[](https://www.python.org/downloads/release/python-31012/)
[](https://developer.nvidia.com/cuda-downloads)
[](https://pytorch.org)
-[](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
+[](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
[](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)

[Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html) | [Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html) | [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) | [Documentation](https://nvidia.github.io/TensorRT-LLM/) | [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
@@ -68,6 +68,7 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
       ON)
option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
       "Using open sourced Cutlass AR gemm kernel" ON)
+option(SKIP_SOFTMAX_STAT "Enable Statistics of Skip-Softmax" OFF)

message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")

@@ -360,6 +361,11 @@ else()
    $<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
endif()

+if(SKIP_SOFTMAX_STAT)
+  add_compile_definitions("SKIP_SOFTMAX_STAT")
+  message(STATUS "SKIP_SOFTMAX_STAT is enabled")
+endif()

# Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
# be found in
# https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html#index-mcmodel_003dmedium-1
@@ -380,6 +380,7 @@ public:
        , mBeamWidth(beamWidth)
        , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
        , mNumFrontBlocksRemoved(0)
+       , mCurrentPrepopulatedPromptLen(std::numeric_limits<SizeType32>::max())
    {
        auto const numWindowSizes = windowSizeToMetadata.size();
        mCacheBlockIds.reserve(numWindowSizes);

@@ -500,6 +501,20 @@ public:
        return mKvCacheRetentionConfig.getDirectory();
    }

+   [[nodiscard]] SizeType32 getCurrentPrepopulatedPromptLen() const
+   {
+       return mCurrentPrepopulatedPromptLen;
+   }
+
+   void setCurrentPrepopulatedPromptLen(SizeType32 currentPrepopulatedPromptLen)
+   {
+       TLLM_CHECK_WITH_INFO(currentPrepopulatedPromptLen <= mCurrentPrepopulatedPromptLen,
+           "currentPrepopulatedPromptLen must be updated non-increasingly due to the "
+           "assumption that smaller window sizes have shorter or equal"
+           "currentPrepopulatedPromptLen in WindowSizeManager::loadOrAllocateBlocks.");
+       mCurrentPrepopulatedPromptLen = currentPrepopulatedPromptLen;
+   }

private:
    // Request id of the sequence
    LlmRequest::RequestIdType mRequestId;

@@ -517,6 +532,8 @@ private:
    SizeType32 mNumFrontBlocksRemoved;
    // Set of used blocks by the sequence
    std::set<KVCacheBlock::IdType> mUsedBlocks;
+   // Current prepopulated prompt length
+   SizeType32 mCurrentPrepopulatedPromptLen;
};

// attach metadata to a pool pointer

@@ -631,7 +648,7 @@ public:

    void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);

-   [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+   [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
        GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);

    void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

@@ -836,8 +853,8 @@ public:
    //! \param blockKeys Key of each block.
    //! \param blockIds Id of each block.
    //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
-   //! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
-   [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
+   //! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
+   [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
        std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
        bool pinBlocks = false);

@@ -869,8 +886,8 @@ public:

    [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

-   //! \brief Unpin blocks by starting from a block id and walking prev pointers.
-   void unpinBlocksById(KVCacheBlock::IdType blockId);
+   //! \brief Unpin blocks by block ids directly
+   void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);

    void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
    {

@@ -1086,7 +1103,7 @@ public:
    std::optional<KVCacheBlock::IdType> releaseBlocks(
        GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

-   [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+   [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
        GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

    void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);

@@ -1095,7 +1112,7 @@ public:
    /// @param sequence The generation request whose blocks should be pinned.
    void pinBlocks(GenerationRequest& sequence);

-   void unpinBlocksById(KVCacheBlock::IdType blockId);
+   void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);

    void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);

@@ -1116,7 +1133,7 @@ public:
    void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
        executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

-   [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
+   [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
        std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
        SizeType32 windowSize, bool pinBlocks = false)
    {

@@ -1567,7 +1584,7 @@ public:
    virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;

    /// \brief Store blocks for reuse for a given request id
-   [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+   [[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
        LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
        = 0;

@@ -1661,7 +1678,7 @@ public:
        BlockKey const& blockKey, SizeType32 windowSize)
        = 0;

-   virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
+   virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
};

class KVCacheManager : public BaseKVCacheManager

@@ -1922,7 +1939,7 @@ public:
    //! \brief Store newest blocks for reuse
    void storeNewBlock(LlmRequest const& llmRequest) override;

-   [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+   [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
        LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;

    [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);

@@ -1943,7 +1960,7 @@ public:

    void pinBlocks(LlmRequest::RequestIdType requestId) override;

-   void unpinBlocksById(KVCacheBlock::IdType blockId) override;
+   void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;

    std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;
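The pin-on-store path now reports every pinned block, so release can happen by ID list. A minimal usage sketch, not taken from the diff, relying only on the `storeBlocksForReuse` and `unpinBlocksById` signatures shown above; the function name and the snapshot step are hypothetical:

```cpp
#include <vector>

// Hypothetical helper: pin blocks while storing them for reuse, do some work,
// then unpin exactly the blocks that were pinned.
void checkpointAndRelease(BaseKVCacheManager& cacheManager,
    LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest)
{
    // Pin on store: the updated signature returns all pinned block IDs instead of
    // only the last one, so callers no longer need to walk prev pointers.
    std::vector<KVCacheBlock::IdType> pinnedIds
        = cacheManager.storeBlocksForReuse(requestId, llmRequest, /*pinBlocks=*/true);

    // ... transfer or snapshot the pinned blocks here (placeholder) ...

    // Unpin the same blocks directly by ID.
    cacheManager.unpinBlocksById(pinnedIds);
}
```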
@@ -1667,6 +1667,12 @@ public:
            [](auto reason) { return reason == executor::FinishReason::kLENGTH; });
    }

+   [[nodiscard]] bool isFinishedDueToCancellation() const noexcept
+   {
+       return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
+           [](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
+   }

    [[nodiscard]] bool isTimedOut() const
    {
        if (!mAllottedTimeMs.has_value())
@@ -17,6 +17,7 @@
#pragma once

#include "tensorrt_llm/executor/serialization.h"
+#include <atomic>
#include <vector>

namespace tensorrt_llm::executor::kv_cache

@@ -27,8 +28,9 @@ class CommState;
struct DataContext
{
public:
-   explicit DataContext(int tag)
+   explicit DataContext(int tag, std::atomic<bool> const& transferTerminate = sDefaultTransferTerminate)
        : mTag{tag}
+       , mTransferTerminate(transferTerminate)
    {
    }

@@ -37,8 +39,15 @@ public:
        return mTag;
    }

+   [[nodiscard]] std::atomic<bool> const& getTransferTerminate() const noexcept
+   {
+       return mTransferTerminate;
+   }

private:
+   inline static std::atomic<bool> sDefaultTransferTerminate{false};
    int const mTag;
+   std::atomic<bool> const& mTransferTerminate;
};

class Connection
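With the terminate flag threaded through `DataContext`, long-running transfers can be cancelled cooperatively. A minimal sketch, not from the diff; `Chunk` and `sendOneChunk` are hypothetical placeholders:

```cpp
#include <atomic>
#include <vector>

struct Chunk {};                           // hypothetical payload type
void sendOneChunk(int tag, Chunk const&);  // hypothetical send primitive

// Poll the terminate flag carried by DataContext between chunks.
bool sendAllChunks(DataContext const& ctx, std::vector<Chunk> const& chunks)
{
    for (auto const& chunk : chunks)
    {
        if (ctx.getTransferTerminate().load(std::memory_order_relaxed))
        {
            return false; // owner requested termination; stop early
        }
        sendOneChunk(ctx.getTag(), chunk);
    }
    return true;
}
```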
@@ -1468,7 +1468,8 @@ public:
        DEFAULT = 0,
        MPI = 1,
        UCX = 2,
-       NIXL = 3
+       NIXL = 3,
+       MOONCAKE = 4
    };
    explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
        std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,
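A one-line sketch (not from the diff) of opting into the new backend via the constructor shown above; the enclosing executor namespace is assumed, and the remaining parameters keep their defaults:

```cpp
// Select the Mooncake cache-transfer backend; maxNumTokens and kvTransferTimeoutMs
// keep their std::nullopt defaults from the constructor shown above.
using tensorrt_llm::executor::CacheTransceiverConfig; // namespace assumed
CacheTransceiverConfig const config(CacheTransceiverConfig::BackendType::MOONCAKE);
```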
@@ -274,13 +274,20 @@ private:
    std::optional<SyncMessage> mSyncMessage;
};

+enum class TransferState : uint8_t
+{
+    kIN_PROGRESS,
+    kSUCCESS,
+    kFAILURE,
+};

// Data structure for checking the status of active transfer operations.
class TransferStatus
{
public:
    virtual ~TransferStatus() = default;
    [[nodiscard]] virtual bool isCompleted() const = 0;
-   virtual void wait() const = 0;
+   virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
};

struct BaseAgentConfig

@@ -288,6 +295,8 @@ struct BaseAgentConfig
    std::string mName;
    bool useProgThread;
    bool multiThread;
    bool useListenThread;
    unsigned int numWorkers;
};

class BaseTransferAgent

@@ -391,6 +400,14 @@ template <typename... Args>
            "libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
        return func(std::forward<Args>(args)...);
    }
+   if (backend == "mooncake")
+   {
+       auto& loader = DynLibLoader::getInstance();
+       using CreateMooncakeFuncType = std::unique_ptr<BaseTransferAgent> (*)(BaseAgentConfig const*);
+       auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
+           "libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
+       return func(std::forward<Args>(args)...);
+   }
    TLLM_THROW("Unknown backend name.");
}
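The blocking `wait()` becomes a timeout-aware poll that reports one of the three `TransferState` values. A minimal sketch, not from the diff; the 5000 ms budget is an arbitrary example:

```cpp
#include <cstdint>
#include <string>

// Interpret the result of the timeout-aware wait().
std::string describeTransfer(TransferStatus const& status)
{
    switch (status.wait(/*timeout_ms=*/5000))
    {
    case TransferState::kSUCCESS:
        return "transfer finished";
    case TransferState::kFAILURE:
        return "transfer failed";
    case TransferState::kIN_PROGRESS:
    default:
        return "still in progress after 5 s; caller may poll again or cancel";
    }
}
```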
@@ -104,12 +104,14 @@ public:

    [[nodiscard]] SizeType32 constexpr getTensorParallelRank() const noexcept
    {
-       return mRank % mTensorParallelism;
+       // Layout: pp is outermost, then tp, then cp is innermost (consecutive).
+       return (mRank % (mTensorParallelism * mContextParallelism)) / mContextParallelism;
    }

    [[nodiscard]] SizeType32 constexpr getContextParallelRank() const noexcept
    {
-       return (mRank % (mTensorParallelism * mContextParallelism)) / mTensorParallelism;
+       // Layout: pp is outermost, then tp, then cp is innermost (consecutive).
+       return mRank % mContextParallelism;
    }

    [[nodiscard]] SizeType32 constexpr getLocalRank() const noexcept
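The comment fixes the flat-rank layout as pp outermost, tp in the middle, cp innermost. A small worked example, not from the diff, that enumerates an 8-rank world (tp = cp = pp = 2) with the same arithmetic; the pp formula is inferred from the stated layout:

```cpp
#include <cstdio>

// Minimal sketch: decompose a flat rank under the new layout (cp innermost,
// then tp, then pp outermost). tp=2, cp=2, pp=2 are example sizes.
int main()
{
    int const tp = 2, cp = 2, pp = 2;
    for (int rank = 0; rank < tp * cp * pp; ++rank)
    {
        int const cpRank = rank % cp;               // matches getContextParallelRank()
        int const tpRank = (rank % (tp * cp)) / cp; // matches getTensorParallelRank()
        int const ppRank = rank / (tp * cp);        // pp is outermost (inferred)
        std::printf("rank %d -> pp %d, tp %d, cp %d\n", rank, ppRank, tpRank, cpRank);
    }
    return 0; // e.g. rank 5 -> pp 1, tp 0, cp 1
}
```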
@@ -69,6 +69,11 @@ PREPROCESSOR_FLAGS += -DUSE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE
# Do we want to use half accumulation for flash attention
PREPROCESSOR_FLAGS += -DHALF_ACCUMULATION_FOR_FLASH_ATTENTION

+# Print the resulting sparsity for the given threshold in Skip-Softmax attention.
+# Note: you only need "python scripts/build_wheel.py -D SKIP_SOFTMAX_STAT=ON ..." to use it inside TRT-LLM.
+# Turn this on manually only if you want to build and run the unit test (bin/fmha.exe) with SKIP_SOFTMAX_STAT.
+# PREPROCESSOR_FLAGS += -DSKIP_SOFTMAX_STAT

# Add FLAGS when generating cubins.
ifdef GENERATE_CUBIN
PREPROCESSOR_FLAGS += -DGENERATE_CUBIN
@ -154,7 +154,9 @@ spec_fields = (
|
||||
'head_size_v',
|
||||
'sage_block_sizes',
|
||||
'output_dtype',
|
||||
'is_mtp')
|
||||
'is_mtp',
|
||||
'enable_skip_softmax',
|
||||
)
|
||||
kernel_spec = namedtuple('kernel_spec', spec_fields)
|
||||
kernel_spec.__new__.__defaults__ = (
|
||||
1, # ctas_per_head
|
||||
@ -179,7 +181,9 @@ kernel_spec.__new__.__defaults__ = (
|
||||
0, # head size of V
|
||||
None, # sage_block_sizes
|
||||
None, # output_dtype, same as dtype by default.
|
||||
False) # use MTP or not
|
||||
False, # use MTP or not
|
||||
False, # enable skip softmax
|
||||
)
|
||||
|
||||
generate_cu_trtllm = os.environ.get('GENERATE_CU_TRTLLM',
|
||||
'False').lower() == 'true'
|
||||
@ -1435,6 +1439,7 @@ using Ktraits = {kernel_traits_header}
|
||||
USE_TMA_STORE,
|
||||
{enable_attn_logit_softcapping_flag},
|
||||
{return_softmax_stats_flag},
|
||||
{enable_skip_softmax_flag},
|
||||
{output_dtype_},
|
||||
{sage_block_size_q},
|
||||
{sage_block_size_k},
|
||||
@ -1458,6 +1463,7 @@ using Ktraits_causal = {kernel_traits_header}
|
||||
USE_TMA_STORE,
|
||||
{enable_attn_logit_softcapping_flag},
|
||||
{return_softmax_stats_flag},
|
||||
{enable_skip_softmax_flag},
|
||||
{output_dtype_}>;
|
||||
|
||||
using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
|
||||
@ -1478,6 +1484,7 @@ using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
|
||||
USE_TMA_STORE && false,
|
||||
{enable_attn_logit_softcapping_flag},
|
||||
{return_softmax_stats_flag},
|
||||
{enable_skip_softmax_flag},
|
||||
{output_dtype_}>;
|
||||
|
||||
using Ktraits_custom_mask = {kernel_traits_header}
|
||||
@ -1498,6 +1505,7 @@ using Ktraits_custom_mask = {kernel_traits_header}
|
||||
USE_TMA_STORE && false,
|
||||
{enable_attn_logit_softcapping_flag},
|
||||
{return_softmax_stats_flag},
|
||||
{enable_skip_softmax_flag},
|
||||
{output_dtype_}>;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1835,6 +1843,8 @@ def encode_name(kernel_spec):
|
||||
|
||||
if kernel_spec.enable_attn_logit_softcapping:
|
||||
feature_tags += '_softcapping'
|
||||
if kernel_spec.enable_skip_softmax:
|
||||
feature_tags += '_skipSoftmax'
|
||||
if kernel_spec.sage_block_sizes:
|
||||
feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}"
|
||||
if kernel_spec.output_dtype:
|
||||
@ -2131,6 +2141,8 @@ def get_kernel_code(kspec, kname, lname):
|
||||
|
||||
return_softmax_stats_flag = pythonBoolean2cpp[kspec.return_softmax_stats]
|
||||
|
||||
enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
|
||||
|
||||
# needed by warpspec kernels.
|
||||
fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"]
|
||||
kernel_traits_header = "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" if fp8_kernel \
|
||||
@ -2331,6 +2343,8 @@ def get_api_code(specs_names):
|
||||
f'&& sage_block_size_k == {sage_block_size_k} ' \
|
||||
f'&& sage_block_size_v == {sage_block_size_v} '
|
||||
|
||||
il_check += '&& enable_skip_softmax ' if kspec.enable_skip_softmax else '&& !enable_skip_softmax '
|
||||
|
||||
il_check += '&& params.use_int8_scale_max ' if kspec.has_scale_max else '&& !params.use_int8_scale_max '
|
||||
|
||||
slen = kspec.seq_len * kspec.ctas_per_head if not kspec.flash_attention else 0
|
||||
@ -2607,6 +2621,7 @@ const bool warp_specialization = launch_params.warp_specialization
|
||||
const bool use_tma = launch_params.use_tma;
|
||||
const bool use_flash_attention = launch_params.flash_attention;
|
||||
const bool enable_attn_logit_softcapping = launch_params.enable_attn_logit_softcapping;
|
||||
const bool enable_skip_softmax = launch_params.enable_skip_softmax;
|
||||
const int attention_input_layout = static_cast<int>(launch_params.attention_input_layout);
|
||||
// tiled variant uses ldgsts
|
||||
const bool use_tiled = launch_params.use_granular_tiling;
|
||||
@ -2785,6 +2800,8 @@ def get_kernel_traits_code(specs_names):
|
||||
enable_attn_logit_softcapping_flag = pythonBoolean2cpp[
|
||||
kspec.enable_attn_logit_softcapping]
|
||||
|
||||
enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
|
||||
|
||||
tmp = dict(locals(), **kspec._asdict())
|
||||
|
||||
if effective_sm < 90:
|
||||
@ -2903,7 +2920,8 @@ def get_kernel_traits_code(specs_names):
|
||||
{input_layout_flag},
|
||||
__use_tma_store__ /* USE_TMA_STORE */,
|
||||
{enable_attn_logit_softcapping_flag},
|
||||
{return_softmax_stats_flag}>;
|
||||
{return_softmax_stats_flag},
|
||||
{enable_skip_softmax_flag}>;
|
||||
|
||||
printf("%s %d %d %s %d %d\\n",
|
||||
\"{kname}\",
|
||||
@ -3062,9 +3080,16 @@ def get_kernel_traits_code(specs_names):
|
||||
# For now:
|
||||
# 1. Hopper head_size 128 kernel uses cubins for performance regressions.
|
||||
# 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins for accuracy regressions (will be fixed).
|
||||
# 3. For skip-softmax attention feature, we force not to use cubins.
|
||||
# You should set the condition `use_cubin_header` to false if you have modified the source codes of those kernels that use cubins.
|
||||
# This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins.
|
||||
def use_cubin_header(sm, head_size, dtype, output_dtype=None):
|
||||
def use_cubin_header(sm,
|
||||
head_size,
|
||||
dtype,
|
||||
output_dtype=None,
|
||||
enable_skip_softmax=False):
|
||||
if enable_skip_softmax:
|
||||
return False
|
||||
if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
|
||||
return False
|
||||
return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)
|
||||
@ -3079,7 +3104,8 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
launchers_dict = {}
|
||||
for kspec, fname, lname, kname in specs_names:
|
||||
if generate_cu_trtllm and not use_cubin_header(
|
||||
kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype):
|
||||
kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype,
|
||||
kspec.enable_skip_softmax):
|
||||
continue
|
||||
name = fname.replace('.', '_')
|
||||
data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
|
||||
@ -3111,8 +3137,9 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
'q_kv_', '').replace('q_paged_kv_', '').replace(
|
||||
'q_k_v_', '').replace('ws_', '').replace(
|
||||
'softcapping_',
|
||||
'').replace('sage_',
|
||||
'').replace('output_', ''))
|
||||
'').replace('sage_', '').replace(
|
||||
'skipSoftmax_',
|
||||
'').replace('output_', ''))
|
||||
flash_attention = 'flash_attention' in kname
|
||||
warp_specialization = 'tma_ws' in kname
|
||||
toks = tname.split('_')
|
||||
@ -3209,6 +3236,8 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
return_softmax_stats_flag = pythonBoolean2cpp[sm != '90' or (
|
||||
sm == '90' and '_softmax' in kname)]
|
||||
|
||||
enable_skip_softmax_flag = pythonBoolean2cpp['_skipSoftmax' in kname]
|
||||
|
||||
# meta_unroll_step
|
||||
meta_unroll_step = unroll_step if ('_nl' in kname
|
||||
or '_ws' in kname) else '0'
|
||||
@ -3235,7 +3264,8 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
|
||||
def get_lname_from_kname(kname: str) -> str:
|
||||
if use_cubin_header(int(sm), int(head_size), prec.lower(),
|
||||
output_prec.lower()):
|
||||
output_prec.lower(),
|
||||
enable_skip_softmax_flag):
|
||||
return 'nullptr'
|
||||
lname = kname.replace('_kernel', '')
|
||||
mask_types = [
|
||||
@ -3253,15 +3283,15 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
|
||||
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
|
||||
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
|
||||
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
|
||||
'''.format(**locals()) if use_cubin_header(int(sm),
|
||||
int(head_size), prec.lower(),
|
||||
output_prec.lower()) else '''\
|
||||
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
|
||||
'''.format(**locals()) if use_cubin_header(int(sm), int(head_size),
|
||||
prec.lower(), output_prec.lower(),
|
||||
enable_skip_softmax_flag) else '''\
|
||||
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
|
||||
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
|
||||
0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
|
||||
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
|
||||
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
|
||||
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
|
||||
'''.format(**locals())
|
||||
else:
|
||||
code = '''\
|
||||
@ -3269,7 +3299,7 @@ def get_cubin_header(kernel_traits, specs_names):
|
||||
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
|
||||
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
|
||||
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
|
||||
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}}}\
|
||||
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}}}\
|
||||
'''.format(**locals())
|
||||
if sm in metadata_v2_dict:
|
||||
metadata_v2_dict[sm].append(code)
|
||||
@ -3377,7 +3407,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
bool mAlibiSupported;
|
||||
bool mTiled;
|
||||
bool mEnableAttnLogitSoftcapping;
|
||||
bool mReturnSoftmaxStats;{launcher_line}
|
||||
bool mReturnSoftmaxStats;
|
||||
bool mEnableSkipSoftmax;{launcher_line}
|
||||
}} sMhaKernelMetaInfosV2[] = {{
|
||||
{metadata_v2}
|
||||
}};
|
||||
@ -3438,6 +3469,7 @@ static const struct TestMetaV2
|
||||
bool mTiled;
|
||||
bool mEnableAttnLogitSoftcapping;
|
||||
bool mReturnSoftmaxStats;
|
||||
bool mEnableSkipSoftmax;
|
||||
}} metaV2[] = {{
|
||||
{metadata_v2}
|
||||
}};
|
||||
@ -3484,7 +3516,8 @@ struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
bool mAlibiSupported;
|
||||
bool mTiled;
|
||||
bool mEnableAttnLogitSoftcapping;
|
||||
bool mReturnSoftmaxStats;{launcher_line}
|
||||
bool mReturnSoftmaxStats;
|
||||
bool mEnableSkipSoftmax;{launcher_line}
|
||||
}};
|
||||
|
||||
extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[];
|
||||
@ -3580,7 +3613,8 @@ struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
bool mAlibiSupported;
|
||||
bool mTiled;
|
||||
bool mEnableAttnLogitSoftcapping;
|
||||
bool mReturnSoftmaxStats;{launcher_line}
|
||||
bool mReturnSoftmaxStats;
|
||||
bool mEnableSkipSoftmax;{launcher_line}
|
||||
}};
|
||||
|
||||
extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[] = {{
|
||||
@ -3637,7 +3671,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_
|
||||
return '\n'.join(lines)
|
||||
|
||||
target = "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled"
|
||||
new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, nullptr},'
|
||||
new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, false, nullptr},'
|
||||
result = modify_kernel_line(result, target, new_line)
|
||||
|
||||
# make sure only one empty line at the end
|
||||
@ -3801,7 +3835,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):
|
||||
|
||||
|
||||
# Note this will be used in TRT-LLM.
|
||||
def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
|
||||
def enumerate_hgmma_flash_warpspec_kernels(specs,
|
||||
sm=90,
|
||||
dtype='fp16',
|
||||
enable_skip_softmax=False):
|
||||
|
||||
scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
|
||||
|
||||
@ -3851,7 +3888,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
input_layout=input_layout,
|
||||
enable_skip_softmax=enable_skip_softmax))
|
||||
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
@ -3883,7 +3921,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
input_layout=input_layout,
|
||||
enable_skip_softmax=enable_skip_softmax))
|
||||
|
||||
specs.append(
|
||||
kernel_spec(
|
||||
@ -3915,7 +3954,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
|
||||
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
||||
return_softmax_stats=return_softmax,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout))
|
||||
input_layout=input_layout,
|
||||
enable_skip_softmax=enable_skip_softmax))
|
||||
'''
|
||||
smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
|
||||
+ (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
|
||||
@ -3967,7 +4007,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
|
||||
sm=90,
|
||||
dtype='e4m3',
|
||||
sage_block_sizes=None,
|
||||
output_dtype=None):
|
||||
output_dtype=None,
|
||||
enable_skip_softmax=False):
|
||||
|
||||
scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
|
||||
|
||||
@ -4021,7 +4062,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout,
|
||||
sage_block_sizes=sage_block_sizes,
|
||||
output_dtype=output_dtype))
|
||||
output_dtype=output_dtype,
|
||||
enable_skip_softmax=enable_skip_softmax))
|
||||
|
||||
# 64 < D <=128: KV_STEP = 128
|
||||
specs.append(
|
||||
@ -4056,7 +4098,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout,
|
||||
sage_block_sizes=sage_block_sizes,
|
||||
output_dtype=output_dtype))
|
||||
output_dtype=output_dtype,
|
||||
enable_skip_softmax=enable_skip_softmax))
|
||||
|
||||
# 128 < D <=256: KV_STEP = 128
|
||||
specs.append(
|
||||
@ -4092,7 +4135,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
|
||||
scheduling_mode=scheduling_mode,
|
||||
input_layout=input_layout,
|
||||
sage_block_sizes=sage_block_sizes,
|
||||
output_dtype=output_dtype))
|
||||
output_dtype=output_dtype,
|
||||
enable_skip_softmax=enable_skip_softmax))
|
||||
|
||||
if not skip_mla_combination:
|
||||
# context MLA (192x128)
|
||||
@@ -6374,13 +6418,21 @@ def enumerate_kernels():
    enumerate_igmma_kernels(specs, sm=90)
    enumerate_qgmma_kernels(specs, sm=90)
    # need to add bf16 kernels if needed
-   enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16')
-   enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='bf16')
-   enumerate_qgmma_flash_warpspec_kernels(specs, sm=90, dtype='e4m3')
-   enumerate_qgmma_flash_warpspec_kernels(specs,
-                                          sm=90,
-                                          dtype='e4m3',
-                                          output_dtype="bf16")
+   for enable_skip_softmax in [False, True]:
+       if enable_skip_softmax and 'DISABLE_SKIP_SOFTMAX' in os.environ:
+           continue
+       enumerate_hgmma_flash_warpspec_kernels(
+           specs, sm=90, dtype='fp16', enable_skip_softmax=enable_skip_softmax)
+       enumerate_hgmma_flash_warpspec_kernels(
+           specs, sm=90, dtype='bf16', enable_skip_softmax=enable_skip_softmax)
+       enumerate_qgmma_flash_warpspec_kernels(
+           specs, sm=90, dtype='e4m3', enable_skip_softmax=enable_skip_softmax)
+       enumerate_qgmma_flash_warpspec_kernels(
+           specs,
+           sm=90,
+           dtype='e4m3',
+           output_dtype="bf16",
+           enable_skip_softmax=enable_skip_softmax)

    # For now SageAttention only needs BF16
    # block_size_q should be divisible by 64
@@ -256,7 +256,8 @@ struct Compute
        actual_kv_seqlen, alibi_head_scale, \
        USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \
                        : (q_step_idx * STEP_Q + head_info.q_tile_offset), \
-       kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, kv_step_idx == kv_idx_end - 1);
+       kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \
+       &shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1);

////////////////////////////////////////////////////////////////////////////////////////////////

@@ -360,6 +361,12 @@ struct Compute
        // Contiguous QKV FMHA assumes q, and kv have the same sequence length.
        int const actual_kv_seqlen = SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen;

+       // Update threshold of Skip-Softmax
+       if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
+       {
+           softmax.skip_softmax_threshold = params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
+       }

        // Calculate the alibi head_scaling_factor.
        float alibi_head_scale
            = APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>(head_info.bidh, params.alibi_params) : 0.f;
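The scale factor is divided by the actual KV length, so the per-block skip test gets stricter for longer sequences. A standalone sketch of the criterion, not from the diff; array sizes and names are illustrative:

```cpp
#include <cmath>

// Minimal sketch of the skip test applied to one row block: a block may be
// skipped only when every row's new local max is so far below the running
// global max that its softmax contribution falls under the threshold.
bool wouldSkipBlock(float const* localMax, float const* globalMax, int rows,
                    float thresholdScaleFactor, int actualKvSeqlen)
{
    // Mirrors: skip_softmax_threshold = skip_softmax_threshold_scale_factor / actual_kv_seqlen
    float const threshold = thresholdScaleFactor / static_cast<float>(actualKvSeqlen);
    bool skip = true;
    for (int mi = 0; mi < rows; ++mi)
    {
        // Mirrors: skip &= expf(local_max - global_max) < skip_softmax_threshold
        skip = skip && (std::exp(localMax[mi] - globalMax[mi]) < threshold);
    }
    return skip;
}
```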
@@ -513,6 +520,13 @@ struct Compute
            }
        }
    }
+#ifdef SKIP_SOFTMAX_STAT
+   if (tidx == 0)
+   {
+       atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks);
+       atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks);
+   }
+#endif
    }

////////////////////////////////////////////////////////////////////////////////////////////////

@@ -522,8 +536,15 @@ struct Compute
        Compute_tile_o& ctile_o, float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M],
        int const tidx, int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset,
        int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, Circular_buffer_kv_reader& cbr_v,
-       OrderedMutexAccessor& mutex, bool complete = false)
+       OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote, bool complete = false)
    {

+       // Skip-softmax vote initialization
+       if (tidx == 0)
+       {
+           // Note that we need a named_barrier_wait in compute_single_tile to make sure init is before voting.
+           *skip_softmax_vote = 1;
+       }
        // load the scales of K/V from global memory
#define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \
    if constexpr (block_size > 0) \

@@ -557,6 +578,10 @@ struct Compute
        // Ctile_p is only used once by each n step.
        ctile_p.clear();

+       // If skip_softmax is enabled, make sure there is no racing between the initialization and writing of
+       // skip_softmax_vote.
+       named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128);

        // BMM1 (Q x K').
        warpgroup_arrive();

@@ -626,8 +651,22 @@ struct Compute
        softmax.apply_alibi_and_mask<APPLY_MASK>(
            ctile_p, params.alibi_params, alibi_head_scale, actual_kv_seqlen, row_offset, col_offset);

-       // Softmax Exp, max/sum, and update scales.
-       softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum);
+       // Softmax Exp, max/sum, and update scales. If returns false we skip the rest.
+       if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote))
+       {
+           if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1)
+           {
+               // Notify another warpgroup to execute QGMMA.
+               mutex.named_bar_arrive();
+           }
+           // Need to wait V, otherwise compute-sanitizer synccheck will fail.
+           int ready2 = cbr_v.peek();
+           if (!ready2)
+           {
+               cbr_v.wait();
+           }
+           return;
+       }

        // experiments show that here is the best place to load scales of V
        float scales_v[SAGE_BLOCKS_PER_STEP_V];
@@ -17,6 +17,8 @@

#pragma once

+#include "fmha/hopper/arrive_wait.h"

#include <fmha/softmax.h>
#include <fmha/traits.h>
#include <fmha/utils.h>

@@ -104,6 +106,12 @@ struct Softmax_base
        CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK
    };

+   // There are 2 warpgroups so 0x3 and 0x4 are used
+   enum
+   {
+       SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID
+   };

    // Ctor.
    template <typename Params>
    inline __device__ Softmax_base(Params params, int tidx)

@@ -114,6 +122,11 @@ struct Softmax_base
        , log2_chunked_attention_size_(params.log2_chunked_attention_size)
        , packed_mask_ptr_{reinterpret_cast<uint32_t*>(params.packed_mask_ptr)}
        , params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes}
+#ifdef SKIP_SOFTMAX_STAT
+       , total_blocks(0)
+       , skipped_blocks(0)
+#endif
+       , skip_softmax_threshold(0)
    {

        int warp = tidx / 32;
@ -330,24 +343,22 @@ struct Softmax_base
|
||||
}
|
||||
|
||||
// Calculate max/sum, and update flash-attention scales.
// Returns false if skipped due to skip-softmax attention feature.
template <bool IS_FIRST_COL>
inline __device__ void compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
inline __device__ bool compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
{
float const scale = reinterpret_cast<float const&>(scale_bmm1_);

// whether this warpgroup skips the softmax
constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
bool skip = may_skip;

// Row-wise max of current tile.
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
{
if (IS_FIRST_COL)
{
local_max_[mi] = elt_[mi][0];
}
else
{
local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
}
local_max_[mi] = elt_[mi][0];
#pragma unroll
for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
{
@ -355,6 +366,56 @@ struct Softmax_base
}
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);

if constexpr (may_skip)
{
// AND(&) the CORES_M results, then `skip` means whether to skip
// the CORES_M(=2) rows
if constexpr (!EXP2F_OPTIMIZATION)
{
skip &= expf(local_max_[mi] - global_max[mi]) < skip_softmax_threshold;
}
else
{
skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold;
}
}

if (!IS_FIRST_COL)
{
local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
}
}

if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
{
#ifdef SKIP_SOFTMAX_STAT
total_blocks++;
#endif
if constexpr (may_skip)
{

// AND(&) the results together in a warp, then `skip` means whether to skip
// all the 16 rows managed by this warp.
// Each group of 4 threads (e.g. T0~T3) has the same `skip`, so a mask of 0x11111111 would suffice
// instead of 0xffffffff, but the perf is the same.
skip = __all_sync(0xffffffff, skip);
if (threadIdx.x % 32 == 0)
{
// The leader of each warp votes.
atomicAnd(skip_softmax_vote, uint32_t(skip));
}
// WG0 uses the 0x3 barrier, WG1 uses the 0x4 barrier.
named_barrier_wait(SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
skip = *((uint32_t volatile*) skip_softmax_vote);
if (skip)
{
#ifdef SKIP_SOFTMAX_STAT
skipped_blocks++;
#endif
return false;
}
}
}

// Softmax Exp.
@ -436,6 +497,7 @@ struct Softmax_base
global_max[mi] = max_new;
}
}
return true;
}

// Update flash attention scales and pack elements for BMM2.
@ -513,6 +575,13 @@ struct Softmax_base
float correction_[Mma_tile_p::CORES_M];
// The packed mask.
uint4 packed_mask_;
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold.
float skip_softmax_threshold;
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax
uint32_t total_blocks;
uint32_t skipped_blocks;
#endif
};

////////////////////////////////////////////////////////////////////////////////////////////////////
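A minimal sketch of the voting scheme used above, with hypothetical names (vote_skip, vote); it omits the EXP2F_OPTIMIZATION path and the two-warpgroup named barrier, and assumes the shared-memory vote word was pre-initialized to 1:

__device__ void vote_skip(float local_max, float global_max, float threshold, uint32_t* vote)
{
    // Per-row decision: the tile contributes almost nothing if its new max barely moves the running max.
    bool skip = expf(local_max - global_max) < threshold;
    // AND the decision across the 32 lanes of the warp.
    skip = __all_sync(0xffffffffu, skip);
    if (threadIdx.x % 32 == 0)
    {
        // Each warp leader ANDs its warp's verdict into the shared-memory vote.
        atomicAnd(vote, uint32_t(skip));
    }
    // After a barrier across the warpgroup, *vote != 0 means every row of the tile may be skipped.
}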
@ -868,9 +937,10 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
|
||||
}
|
||||
|
||||
// Calculate max/sum, and update flash-attention scales.
|
||||
// Returns false if skipped due to skip-softmax attention feature.
|
||||
template <bool IS_FIRST_COL>
|
||||
inline __device__ void compute_and_update_scale(
|
||||
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
|
||||
inline __device__ bool compute_and_update_scale(
|
||||
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
|
||||
{
|
||||
float const scale = reinterpret_cast<float const&>(this->scale_bmm1_);
|
||||
float(&local_max_)[Mma_tile_p::CORES_M] = this->local_max_;
|
||||
@ -878,18 +948,15 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
|
||||
float(&correction_)[Mma_tile_p::CORES_M] = this->correction_;
|
||||
float(&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2] = this->elt_;
|
||||
|
||||
// whether this warpgroup skips the softmax
|
||||
constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
|
||||
bool skip = may_skip;
|
||||
|
||||
// Row-wise max of current tile.
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
|
||||
{
|
||||
if (IS_FIRST_COL)
|
||||
{
|
||||
local_max_[mi] = elt_[mi][0];
|
||||
}
|
||||
else
|
||||
{
|
||||
local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
|
||||
}
|
||||
local_max_[mi] = elt_[mi][0];
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
|
||||
{
|
||||
@ -897,6 +964,56 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
|
||||
}
|
||||
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
|
||||
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
|
||||
// AND(&) the CORES_M results, then `skip` means whether to skip
|
||||
// the CORES_M(=2) rows
|
||||
if constexpr (may_skip)
|
||||
{
|
||||
// AND(&) the CORES_M results, then `skip` means whether to skip
|
||||
// the CORES_M(=2) rows
|
||||
if constexpr (!EXP2F_OPTIMIZATION)
|
||||
{
|
||||
skip &= expf(local_max_[mi] - global_max[mi]) < this->skip_softmax_threshold;
|
||||
}
|
||||
else
|
||||
{
|
||||
skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < this->skip_softmax_threshold;
|
||||
}
|
||||
}
|
||||
if (!IS_FIRST_COL)
|
||||
{
|
||||
local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
|
||||
{
|
||||
#ifdef SKIP_SOFTMAX_STAT
|
||||
this->total_blocks++;
|
||||
#endif
|
||||
|
||||
if constexpr (may_skip)
|
||||
{
|
||||
// AND(&) the results together in a warp, then `skip` means whether to skip
|
||||
// all the 16 rows managed by this warp.
|
||||
// each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed
|
||||
// instead of 0xffffffff. But the perf is the same.
|
||||
skip = __all_sync(0xffffffff, skip);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
{
|
||||
// The leader of each warp votes.
|
||||
atomicAnd(skip_softmax_vote, uint32_t(skip));
|
||||
}
|
||||
// WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
|
||||
named_barrier_wait(Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
|
||||
skip = *((uint32_t volatile*) skip_softmax_vote);
|
||||
if (skip)
|
||||
{
|
||||
#ifdef SKIP_SOFTMAX_STAT
|
||||
this->skipped_blocks++;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Softmax Exp.
|
||||
@ -987,6 +1104,7 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
|
||||
global_max[mi] = max_new;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Update flash attention scales and pack elements for BMM2.
|
||||
|
||||
@ -71,6 +71,8 @@ template <
bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
// Save softmax stats ?
bool RETURN_SOFTMAX_STATS_ = false,
// Enable skip softmax attention feature
bool ENABLE_SKIP_SOFTMAX_ = false,
// The output type (only used by fp8 kernels).
typename OutputType = typename Instruction_traits<STEP_Q_, STEP_KV_, 0, false, false>::A_type,
// The sage attention block size for Q, K and V
@ -290,6 +292,12 @@ struct Kernel_traits
USE_CUSTOM_MASK = ATTENTION_MASK_TYPE_ == 3
};

// Are we enabling skip softmax attention feature?
enum
{
ENABLE_SKIP_SOFTMAX = ENABLE_SKIP_SOFTMAX_
};

static_assert(!USE_CUSTOM_MASK || STEP_KV == 64 || STEP_KV == 128 || STEP_KV == 256, "Not implemented!");

// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
@ -384,6 +392,8 @@ struct Kernel_traits
// Named barrier ids
static constexpr int DMA_SYNC_BARRIER_ID = 0x1;
static constexpr int MMA_SYNC_BARRIER_ID = 0x2;
// There are 2 warpgroups so 0x3 and 0x4 are used for skip-softmax
static constexpr int SKIP_SOFTMAX_BARRIER_ID = 0x3;

// How many threads get involved in the dma group.
enum
@ -518,6 +528,10 @@ struct Kernel_traits
// Mutex
OrderedMutex compute_mutex;

// 4 warps in a warpgroup vote to an atomic variable in shared memory
// to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive KV_STEPS.
uint32_t skip_softmax_votes[2][NUM_COMPUTE_GROUPS];

inline __device__ void init(int tid0)
{

@ -580,6 +594,8 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
// Save softmax stats ?
bool RETURN_SOFTMAX_STATS_ = false,
// Enable skip softmax attention feature
bool ENABLE_SKIP_SOFTMAX_ = false,
// The output type (only used by fp8 kernels).
typename OutputType = e4m3_t,
// The sage attention block size for Q, K and V
@ -588,14 +604,15 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
: public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_,
ENABLE_MUTEX_, SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_,
RETURN_SOFTMAX_STATS_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>
RETURN_SOFTMAX_STATS_, ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_,
SAGE_BLOCK_SIZE_V_>
{

// Base class.
using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_, RETURN_SOFTMAX_STATS_,
OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;

enum
{
@ -693,6 +710,10 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
// Mutex
OrderedMutex compute_mutex;

// 4 warps in a warpgroup vote to an atomic variable in shared memory
// to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive STEP_KVs.
uint32_t skip_softmax_votes[2][Base::NUM_COMPUTE_GROUPS];

inline __device__ void init(int tid0)
{

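For reference, the mapping implied by the barrier comment above, as a one-line sketch (assuming 128 threads per warpgroup, as in the warp-specialized kernels):

// Warpgroup 0 (threads 0..127) waits on barrier 0x3, warpgroup 1 (threads 128..255) on barrier 0x4.
int const barrier_id = SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128;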
@ -276,7 +276,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
// scale factors
|
||||
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
|
||||
// flags
|
||||
bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi)
|
||||
bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi,
|
||||
float const skip_softmax_threshold_scale_factor)
|
||||
{
|
||||
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
@ -421,6 +422,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
|
||||
params.enable_i2f_trick
|
||||
= -double(1 << 22) * double(scale_bmm2) <= -128.f && double(1 << 22) * double(scale_bmm2) >= 127.f;
|
||||
}
|
||||
|
||||
// Skip-softmax attention
|
||||
params.skip_softmax_threshold_scale_factor = skip_softmax_threshold_scale_factor;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -429,7 +433,7 @@ static inline void determine_launch_params(Launch_params& launch_params, Data_ty
|
||||
const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout,
|
||||
bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma,
|
||||
bool const force_non_flash_attention, bool const force_non_warp_specialization,
|
||||
bool const force_non_granular_tiling, bool const force_fp32_acc,
|
||||
bool const force_non_granular_tiling, bool const force_fp32_acc, float const skip_softmax_threshold_scale_factor,
|
||||
// device props
|
||||
const cudaDeviceProp props)
|
||||
{
|
||||
@ -470,6 +474,9 @@ static inline void determine_launch_params(Launch_params& launch_params, Data_ty
|
||||
"are not supported on Ada currently.\n");
|
||||
launch_params.use_granular_tiling = false;
|
||||
}
|
||||
|
||||
// Enable skip softmax attention or not.
|
||||
launch_params.enable_skip_softmax = skip_softmax_threshold_scale_factor > 0.f;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -589,6 +596,9 @@ int main(int argc, char** argv)
|
||||
// Use attention sinks (added to the denominator of softmax)
|
||||
bool use_attention_sinks = false;
|
||||
|
||||
// Skip-softmax attention
float skip_softmax_threshold_scale_factor = 0;

// Read the parameters from the command-line.
for (int ii = 1; ii < argc; ++ii)
{
@ -885,6 +895,10 @@ int main(int argc, char** argv)
{
use_attention_sinks = true;
}
else if (!strcmp(argv[ii], "-skip-softmax-threshold-scale-factor") && ++ii < argc)
{
skip_softmax_threshold_scale_factor = strtof(argv[ii], nullptr);
}
else
{
fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]);
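For reference, a hypothetical invocation of this test harness that enables the feature (the binary name and the other flags are illustrative, not taken from this change):

    ./fmha_v2_test -s 4096 -d 128 -skip-softmax-threshold-scale-factor 200

With a scale factor of 200 and a sequence length of 4096, a KV tile becomes a skip candidate only when exp(local_max - global_max) stays below 200 / 4096 ≈ 0.049 for every row of the tile.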
@ -1057,7 +1071,7 @@ int main(int argc, char** argv)
|
||||
Launch_params launch_params;
|
||||
determine_launch_params(launch_params, data_type, sm, s, d, attention_mask_type, input_layout, interleaved,
|
||||
ignore_b1opt, force_unroll, use_tma, force_non_flash_attention, force_non_warp_specialization,
|
||||
force_non_granular_tiling, force_fp32_acc, props);
|
||||
force_non_granular_tiling, force_fp32_acc, skip_softmax_threshold_scale_factor, props);
|
||||
|
||||
// The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D.
|
||||
const size_t qkv_size = s * b * h * (2 * d + dv);
|
||||
@ -1713,7 +1727,13 @@ int main(int argc, char** argv)
|
||||
tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d,
|
||||
packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d,
|
||||
softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
|
||||
use_int8_scale_max, interleaved, is_s_padded, has_alibi);
|
||||
use_int8_scale_max, interleaved, is_s_padded, has_alibi, skip_softmax_threshold_scale_factor);
|
||||
#ifdef SKIP_SOFTMAX_STAT
|
||||
FMHA_CHECK_CUDA(cudaMalloc(¶ms_v2.skip_softmax_total_blocks, sizeof(uint32_t)));
|
||||
FMHA_CHECK_CUDA(cudaMalloc(¶ms_v2.skip_softmax_skipped_blocks, sizeof(uint32_t)));
|
||||
FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_total_blocks, 0, sizeof(uint32_t)));
|
||||
FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_skipped_blocks, 0, sizeof(uint32_t)));
|
||||
#endif
|
||||
|
||||
// total number of tokens is needed to set TMA desc on the host.
|
||||
launch_params.total_q_seqlen = q_seqlens[b];
|
||||
@ -2101,6 +2121,18 @@ int main(int argc, char** argv)
|
||||
non_fused_elapsed / fused_elapsed, total_flops / (fused_elapsed / float(runs) / 1e-9),
|
||||
total_bytes / (fused_elapsed / float(runs) / 1e-6));
|
||||
}
|
||||
#ifdef SKIP_SOFTMAX_STAT
|
||||
if (skip_softmax_threshold_scale_factor > 0)
|
||||
{
|
||||
uint32_t total_blocks, skipped_blocks;
|
||||
FMHA_CHECK_CUDA(
|
||||
cudaMemcpy(&total_blocks, params_v2.skip_softmax_total_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
|
||||
FMHA_CHECK_CUDA(cudaMemcpy(
|
||||
&skipped_blocks, params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
|
||||
printf("Skip-Softmax .: %u / %u = %.2f%%\n", skipped_blocks, total_blocks,
|
||||
total_blocks ? 100.f * skipped_blocks / total_blocks : 0.f);
|
||||
}
|
||||
#endif
|
||||
#if defined(DEBUG_HAS_PRINT_BUFFER)
|
||||
FMHA_CHECK_CUDA(cuda_memcpy_d2h(print_buffer.data(), params.print_ptr, print_buffer.size(), DATA_TYPE_FP32));
|
||||
|
||||
@ -2141,6 +2173,11 @@ int main(int argc, char** argv)
|
||||
FMHA_CHECK_CUDA(cudaFree(kv_cache_block_offsets_d));
|
||||
FMHA_CHECK_CUDA(cudaFree(contiguous_kv_d));
|
||||
FMHA_CHECK_CUDA(cudaFree(softmax_stats_d));
|
||||
FMHA_CHECK_CUDA(cudaFree(attention_sinks_d));
|
||||
#ifdef SKIP_SOFTMAX_STAT
|
||||
FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_total_blocks));
|
||||
FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_skipped_blocks));
|
||||
#endif
|
||||
|
||||
free(qkv_h);
|
||||
free(mask_h);
|
||||
|
||||
@ -283,6 +283,16 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
float* scales;
} q, k, v;
} sage;

// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
// A positive value means skip-softmax is enabled.
float skip_softmax_threshold_scale_factor = 0;

#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax, pointers of device memory for output
uint32_t* skip_softmax_total_blocks;
uint32_t* skip_softmax_skipped_blocks;
#endif
};

#endif
@ -322,6 +332,8 @@ struct Fused_multihead_attention_launch_params
// hardware properties to determine how to launch blocks
int multi_processor_count = 0;
int device_l2_cache_size = 0;
// skip softmax attention
bool enable_skip_softmax = false;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

@ -177,4 +177,13 @@ struct Fused_multihead_attention_params_v2
float* scales;
} q, k, v;
} sage;

// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
// A positive value means skip-softmax is enabled.
float skip_softmax_threshold_scale_factor = 0;
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax, pointers of device memory for output
uint32_t* skip_softmax_total_blocks;
uint32_t* skip_softmax_skipped_blocks;
#endif
};

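As a worked example of the comment above (values are illustrative): with skip_softmax_threshold_scale_factor = 100 and seqlen = 8192, the per-sequence threshold is 100 / 8192 ≈ 0.0122, so a KV tile is skipped only when exp(local_max - global_max) < 0.0122 holds for every query row. Leaving the scale factor at 0 (the default) disables the feature entirely.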
@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
#define SLIDING_WINDOW 0
#endif

#ifndef SKIP_SOFTMAX_ATTN
#define SKIP_SOFTMAX_ATTN 0
#endif

#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
#endif

#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
#endif

// 0 - no PDL
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)

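The three macros above default the feature to off; a hypothetical build that turns it on together with per-block statistics might add (compiler invocation and file list are illustrative only):

    nvcc -DSKIP_SOFTMAX_ATTN=1 -DSKIP_SOFTMAX_ATTN_BLOCK_STATS=1 ... mha_sm90.cu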
@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
|
||||
asm volatile("trap;\n");
|
||||
return 0;
|
||||
}();
|
||||
assert(__cvta_generic_to_shared(data) % baseAlign == 0);
|
||||
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
|
||||
return MatDesc{
|
||||
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
|
||||
|
||||
@ -2734,6 +2734,25 @@ static constexpr auto kernel_mha = kernel_mha_impl;
|
||||
#endif
|
||||
|
||||
#ifndef GENERATE_CUBIN
|
||||
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
|
||||
{
|
||||
if (!allowMultiBlockMode)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
auto const env = std::getenv("XQA_NB_SUB_SEQ");
|
||||
if (env != nullptr)
|
||||
{
|
||||
int32_t const val = std::stoi(env);
|
||||
if (val > 0)
|
||||
{
|
||||
return val;
|
||||
}
|
||||
}
|
||||
return std::min<uint32_t>(
|
||||
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
|
||||
}
|
||||
|
||||
void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#if SLIDING_WINDOW
|
||||
uint32_t slidingWinSize,
|
||||
@ -2771,6 +2790,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
// int8/fp8 KV cache.
|
||||
#if SPEC_DEC
|
||||
SpecDecParams const& specDecParams,
|
||||
#endif
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
|
||||
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
|
||||
uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only
|
||||
#endif
|
||||
#endif
|
||||
uint32_t* semaphores, void* scratch, cudaStream_t stream)
|
||||
{
|
||||
@ -2793,24 +2819,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
uint32_t const nbQHeads = nbKHeads * headGrpSize;
|
||||
|
||||
// const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1;
|
||||
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
|
||||
{
|
||||
if (!allowMultiBlockMode)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
auto const env = std::getenv("XQA_NB_SUB_SEQ");
|
||||
if (env != nullptr)
|
||||
{
|
||||
int32_t const val = std::stoi(env);
|
||||
if (val > 0)
|
||||
{
|
||||
return val;
|
||||
}
|
||||
}
|
||||
return std::min<uint32_t>(
|
||||
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
|
||||
}();
|
||||
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
|
||||
// gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq
|
||||
#if SPEC_DEC
|
||||
const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock);
|
||||
|
||||
@ -90,6 +90,9 @@ struct BeamSearchParams
|
||||
// match trt-llm API.
|
||||
};
|
||||
|
||||
uint32_t computeNbSubSeqPerSeqMHA(
|
||||
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
|
||||
|
||||
void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
|
||||
#if SLIDING_WINDOW
|
||||
uint32_t slidingWinSize,
|
||||
@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
|
||||
// int8/fp8 KV cache.
|
||||
#if SPEC_DEC
|
||||
SpecDecParams const& specDecParams,
|
||||
#endif
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
float const skipSoftmaxThresholdScaleFactor,
|
||||
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
|
||||
#endif
|
||||
#endif
|
||||
uint32_t* semaphores, void* scratch, cudaStream_t stream);
|
||||
|
||||
uint32_t computeNbSubSeqPerSeqHopperF8MHA(
|
||||
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
|
||||
|
||||
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#if SLIDING_WINDOW
|
||||
uint32_t slidingWinSize,
|
||||
@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
// int8/fp8 KV cache.
|
||||
#if SPEC_DEC
|
||||
SpecDecParams const& specDecParams,
|
||||
#endif
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
float const skipSoftmaxThresholdScaleFactor,
|
||||
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
|
||||
#endif
|
||||
#endif
|
||||
uint32_t* semaphores, void* scratch, cudaStream_t stream);
|
||||
|
||||
|
||||
@ -49,6 +49,10 @@ static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is to
|
||||
#define SWAP_AB (!SPEC_DEC)
|
||||
#endif
|
||||
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
static_assert(SWAP_AB && USE_PAGED_KV_CACHE && !SPEC_DEC && BEAM_WIDTH == 1, "SKIP_SOFTMAX_ATTN is not supported.");
|
||||
#endif
|
||||
|
||||
#define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT)
|
||||
|
||||
inline constexpr bool swapAB = SWAP_AB;
|
||||
@ -138,26 +142,38 @@ using PaddedOutHead = PaddedInputHead;
|
||||
|
||||
struct alignas(128) SharedMem
|
||||
{
|
||||
using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
|
||||
using KBuffer = Array2D<LdGrain, gemm0CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes)>;
|
||||
static constexpr uint32_t nbKBuf = 2;
|
||||
KBuffer k[nbKBuf]; // as is loaded from global mem.
|
||||
using XBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerXPart>, nbXParts>;
|
||||
static constexpr uint32_t nbXBuf
|
||||
= 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
|
||||
using VBuffer = Vec<Array2D<LdGrain, gemm1CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes),
|
||||
sizeof(XBuffer) % (cacheHeadPartBytes * 8) == 0>,
|
||||
cacheHeadNbParts>;
|
||||
#if !SWAP_AB
|
||||
using VTBuffer = Array2D<LdGrain, headElems, exactDiv(gemm1CtaTileNbTokens, cacheElemsPerGrain), true>;
|
||||
#endif
|
||||
static constexpr uint32_t nbVBuf = 2;
|
||||
#if CACHE_ELEM_ENUM == 0
|
||||
using OutSwizzleBuf = Array2D<LdGrain, ctaNbQHeads, grainsPerPaddedInputHead>;
|
||||
#elif CACHE_ELEM_ENUM == 2
|
||||
using OutSwizzleBuf = Array2D<Vec<Vec<InputElem, 4>, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
|
||||
#endif
|
||||
|
||||
#if SKIP_SOFTMAX_ATTN
static constexpr uint32_t nbKBuf = 2;
static constexpr uint32_t nbVBuf = 3; // @fixme: skip_softmax_attn: for skip softmax attn, an extra VBuffer is used
static constexpr uint32_t nbXBuf
= 3 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
#else
static constexpr uint32_t nbKBuf = 2;
static constexpr uint32_t nbVBuf = 2;
static constexpr uint32_t nbXBuf
= 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
#endif
static_assert(nbXBuf == nbVBuf);
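Illustration of the resulting buffer counts for the common case gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens (the reason for the extra slot is one reading of the @fixme above, not stated explicitly in this change):

//   without skip-softmax: nbKBuf = 2, nbVBuf = 2, nbXBuf = 2
//   with skip-softmax:    nbKBuf = 2, nbVBuf = 3, nbXBuf = 3
// The extra X/V slot presumably lets a skipped tile be retired without stalling the producer,
// while the static_assert keeps the X and V ring buffers advancing in lockstep.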
|
||||
// note: buffers used for GMMA may have additional alignment requirements
|
||||
KBuffer k[nbKBuf]; // as is loaded from global mem.
|
||||
QBuffer q; // For gmma math. Conversion done if needed.
|
||||
|
||||
union ReusedXVOutSwizzleBuf
|
||||
{
|
||||
struct XV
|
||||
@ -196,9 +212,6 @@ struct alignas(128) SharedMem
|
||||
return reusedXVOutSwizzleBuf[i].outSwizzle;
|
||||
}
|
||||
|
||||
using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
|
||||
QBuffer q; // For gmma math. Conversion done if needed.
|
||||
|
||||
// @fixme: move these into reusedXVOutSwizzleBuf
|
||||
#if SWAP_AB
|
||||
ShmQWiseVec xColMax[nbXBuf];
|
||||
@ -220,6 +233,11 @@ struct alignas(128) SharedMem
|
||||
Vec<KVCachePageIndex, nbPagesPerTile> pages[2]; // one for K and one for V
|
||||
#endif
|
||||
|
||||
#if SKIP_SOFTMAX_ATTN
uint32_t skipSoftmaxVotesGemm0ToV[nbXBuf]; // guarded by skipSoftmaxXBar
uint32_t skipSoftmaxVotesGemm0ToGemm1[nbXBuf]; // guarded by xBar
#endif

// mem barriers

CtaBarrierPair qBar;
@ -229,6 +247,9 @@ struct alignas(128) SharedMem
CtaBarrierPair vtBar[nbVBuf];
#endif
CtaBarrierPair xBar[nbXBuf];
#if SKIP_SOFTMAX_ATTN
CtaBarrierPair skipSoftmaxXBar[nbXBuf]; // for V to wait for X to be ready
#endif
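A simplified sketch (comment-level pseudocode, not the kernel's exact sequence) of the extra handshake these members add between the gemm0 warpgroup and the V-load warp:

//   gemm0 warpgroup (per buffer i):            V-load warp (per buffer i):
//     skipSoftmaxXBar[i].consumed.wait()         skipSoftmaxXBar[i].produced.wait()
//     skipSoftmaxVotesGemm0ToV[i] = vote         read skipSoftmaxVotesGemm0ToV[i]
//     skipSoftmaxXBar[i].produced.arrive()       skipSoftmaxXBar[i].consumed.arrive()
//     ...                                        if vote != 0: skip the TMA loads for this V tile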
|
||||
// used internally in the gemm0 warp group
|
||||
// @fixme: use separate arrive and wait for all usage
|
||||
@ -425,8 +446,13 @@ __device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
#endif

#if SWAP_AB
#if SKIP_SOFTMAX_ATTN
__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src,
float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip);
#else
__device__ RegColWiseVec computeWarpGrpColMax_sync(
CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
#endif
__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd);
__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax);
__device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
@ -675,6 +701,12 @@ CUBIN_EXPORT __global__
#endif
#if SPEC_DEC
SpecDecParams const specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* __restrict__ const semaphores
= nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
@ -753,6 +785,10 @@ CUBIN_EXPORT __global__
uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1;
static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
assert(isMultiBlockMode == (nbSubSeq > 1));
#if SKIP_SOFTMAX_ATTN
bool const disableSkipForShortSeq = (cacheSeqLen < skipSoftmaxThresholdScaleFactor);
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / cacheSeqLen;
#endif
if (idxSubSeq >= nbSubSeq)
{
return;
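Worked example for the two threshold lines above (values illustrative): with skipSoftmaxThresholdScaleFactor = 200, a cacheSeqLen of 8192 gives a threshold of 200 / 8192 ≈ 0.0244, while a cacheSeqLen of 128 is below the scale factor, so disableSkipForShortSeq is true and the threshold is forced to 0 (no tile is ever skipped for that sequence).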
@ -776,21 +812,34 @@ CUBIN_EXPORT __global__
|
||||
assert(dynamicSmemSize() >= sizeof(SharedMem));
|
||||
SharedMem& smem = *reinterpret_cast<SharedMem*>(&smemByteBuf[0]);
|
||||
|
||||
constexpr uint32_t nbBuffers = 2;
|
||||
static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf);
|
||||
if (wid < nbBuffers)
|
||||
constexpr uint32_t maxNbBuffers = (SharedMem::nbXBuf > SharedMem::nbVBuf) ? SharedMem::nbXBuf : SharedMem::nbVBuf;
|
||||
static_assert(
|
||||
maxNbBuffers >= SharedMem::nbKBuf && maxNbBuffers >= SharedMem::nbVBuf && maxNbBuffers >= SharedMem::nbXBuf);
|
||||
if (wid < maxNbBuffers)
|
||||
{
|
||||
if (warpElectSync())
|
||||
{
|
||||
smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
|
||||
smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
|
||||
#if !SWAP_AB
|
||||
smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
|
||||
if (wid < SharedMem::nbKBuf)
|
||||
{
|
||||
smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
|
||||
}
|
||||
if (wid < SharedMem::nbXBuf)
|
||||
{
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
smem.skipSoftmaxXBar[wid].initialize(gemm0NbThrds + warp_size, gemm0NbThrds + warp_size);
|
||||
smem.vBar[wid].initialize(gemm1NbThrds + warp_size, gemm1NbThrds + warp_size);
|
||||
#else
|
||||
smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
|
||||
#endif
|
||||
smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
|
||||
|
||||
#if !SWAP_AB
|
||||
smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
|
||||
#endif
|
||||
smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (wid == nbBuffers)
|
||||
else if (wid == maxNbBuffers)
|
||||
{
|
||||
if (warpElectSync())
|
||||
{
|
||||
@ -819,6 +868,10 @@ CUBIN_EXPORT __global__
|
||||
SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen};
|
||||
#endif
|
||||
|
||||
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
uint32_t localSkippedBlockCount = 0;
|
||||
#endif
|
||||
|
||||
// QK gemm
|
||||
constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM);
|
||||
using Acc = GmmaAcc<gemm0CtaTileNbTokens, ctaNbQHeads>;
|
||||
@ -940,10 +993,39 @@ CUBIN_EXPORT __global__
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
auto& xBar = smem.xBar[idxXBuf];
// update colMax in shared mem and get a register copy
#if SWAP_AB
#if SKIP_SOFTMAX_ATTN
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
skipSoftmaxXBar.consumed.arrive_and_wait();

bool const maybeSkip = !disableSkipForShortSeq && idxIter != 0;
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc,
skipSoftmaxThreshold, &smem.skipSoftmaxVotesGemm0ToV[idxXBuf], maybeSkip);
bool const shouldSkipSoftmaxAttn = static_cast<bool>(smem.skipSoftmaxVotesGemm0ToV[idxXBuf]);
unused(skipSoftmaxXBar.produced.arrive());
warpGrpOnlineSoftmax(acc, colMax);
if (shouldSkipSoftmaxAttn)
{
xBar.consumed.arrive_and_wait();
if (threadIdx.x == 0)
{
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 1U;
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
localSkippedBlockCount++;
#endif
}
asm volatile("fence.proxy.async.shared::cta;\n"); // maybe not used
unused(xBar.produced.arrive());
continue;
}
#else
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc);
warpGrpOnlineSoftmax(acc, colMax);
#endif
#else
RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc);
warpGrpOnlineSoftmax(acc, rowMax);
@ -959,8 +1041,6 @@ CUBIN_EXPORT __global__
// map 1 to fp8_max before conversion to fp8
acc = acc * kE4M3_MAX;

uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
auto& xBar = smem.xBar[idxXBuf];
// @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM.
#if SWAP_AB
storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc);
@ -989,13 +1069,25 @@ CUBIN_EXPORT __global__
storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax);
storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum);
#endif

#if SKIP_SOFTMAX_ATTN
if (threadIdx.x == 0)
{
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 0;
}
#endif
__syncwarp();
// the release semantics of arrive does not work for async consumers like gmma. additional fence is
// needed.
asm volatile("fence.proxy.async.shared::cta;\n");
unused(xBar.produced.arrive());
}
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
if (threadIdx.x == 0 && skippedBlockCount != nullptr && totalBlockCount != nullptr)
{
atomicAdd(skippedBlockCount, localSkippedBlockCount);
atomicAdd(totalBlockCount, nbIters);
}
#endif
unused(smem.qBar.consumed.arrive());
}
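Condensed view of the gemm0 warpgroup's per-tile control flow shown above (SWAP_AB path only; a summary, not additional code):

//   wait on skipSoftmaxXBar -> compute column max and cast the skip vote -> publish the vote to the V-load warp
//   if the tile was voted skippable: set skipSoftmaxVotesGemm0ToGemm1 = 1, arrive xBar.produced, continue
//   otherwise: online softmax, scale into fp8 range, store X to shared memory, clear the gemm1 vote, fence, arrive xBar.produced
//   after the loop: thread 0 flushes localSkippedBlockCount / nbIters into the global statistic counters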
else if (warpIdx.z == 1)
|
||||
@ -1043,216 +1135,231 @@ CUBIN_EXPORT __global__
|
||||
uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq;
|
||||
auto const idxVBuf = idxIter % SharedMem::nbVBuf;
|
||||
auto const idxXBuf = idxVBuf;
|
||||
auto& vBar = smem.vBar[idxVBuf];
|
||||
arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
|
||||
auto const& vBuf = smem.vBuf(idxVBuf);
|
||||
#if !SWAP_AB
|
||||
CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
|
||||
auto& vtBuf = smem.vtBuf(idxVBuf);
|
||||
vtBar.consumed.arrive_and_wait();
|
||||
transposeVTile(warpRank, laneId(), vtBuf, vBuf);
|
||||
vBar.consumed.arrive();
|
||||
vtBar.produced.arrive();
|
||||
#endif
|
||||
auto& xBar = smem.xBar[idxXBuf];
|
||||
auto& vBar = smem.vBar[idxVBuf];
|
||||
auto const& vBuf = smem.vBuf(idxVBuf);
|
||||
xBar.produced.arrive_and_wait();
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf]; // guarded by xBar
|
||||
if (shouldSkipSoftmaxAttn)
|
||||
{
|
||||
vBar.produced.arrive_and_wait();
|
||||
}
|
||||
#endif
|
||||
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
if (!shouldSkipSoftmaxAttn) // skip XVGemm
|
||||
#endif
|
||||
{
|
||||
arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
|
||||
#if !SWAP_AB
|
||||
CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
|
||||
auto& vtBuf = smem.vtBuf(idxVBuf);
|
||||
vtBar.consumed.arrive_and_wait();
|
||||
transposeVTile(warpRank, laneId(), vtBuf, vBuf);
|
||||
vBar.consumed.arrive();
|
||||
vtBar.produced.arrive();
|
||||
#endif
|
||||
#if !defined(NDEBUG) && DBG_PRINT
|
||||
#if SWAP_AB
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
printf("colMax:\n");
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
printf("%f, ", smem.xColMax[idxXBuf][i]);
|
||||
}
|
||||
printf("\n");
|
||||
printf("colSum:\n");
|
||||
for (int n = 0; n < 4; n++)
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
printf("colMax:\n");
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
printf("%f, ", smem.xColSum[idxXBuf][n][i]);
|
||||
printf("%f, ", smem.xColMax[idxXBuf][i]);
|
||||
}
|
||||
printf("\n");
|
||||
printf("colSum:\n");
|
||||
for (int n = 0; n < 4; n++)
|
||||
{
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
printf("%f, ", smem.xColSum[idxXBuf][n][i]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
printf("X:\n");
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
for (int j = 0; j < gemm0CtaTileNbTokens; j++)
|
||||
{
|
||||
auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
|
||||
auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
|
||||
smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
|
||||
i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
|
||||
printf("%.2f, ", float(e));
|
||||
if (j % 16 == 15)
|
||||
{
|
||||
printf("| ");
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
}
|
||||
}
|
||||
smem.gemm1WarpGrpBar.arrive_and_wait();
|
||||
#else
|
||||
if (blockIdx.y == 1 && threadIdx.x == 0)
|
||||
{
|
||||
printf("rowMax:\n");
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
printf("%f, ", smem.xRowMax[idxXBuf][i]);
|
||||
}
|
||||
printf("\n");
|
||||
printf("rowSum:\n");
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
printf("%f, ", smem.xRowSum[idxXBuf][i]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
printf("X:\n");
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
for (int j = 0; j < gemm0CtaTileNbTokens; j++)
|
||||
{
|
||||
auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
|
||||
auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
|
||||
smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
|
||||
i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
|
||||
printf("%.2f, ", float(e));
|
||||
if (j % 16 == 15)
|
||||
{
|
||||
printf("| ");
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
}
|
||||
}
|
||||
smem.gemm1WarpGrpBar.arrive_and_wait();
|
||||
#else
|
||||
if (blockIdx.y == 1 && threadIdx.x == 0)
|
||||
{
|
||||
printf("rowMax:\n");
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
printf("%f, ", smem.xRowMax[idxXBuf][i]);
|
||||
}
|
||||
printf("\n");
|
||||
printf("rowSum:\n");
|
||||
for (int i = 0; i < ctaNbQHeads; i++)
|
||||
{
|
||||
printf("%f, ", smem.xRowSum[idxXBuf][i]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
smem.gemm1WarpGrpBar.arrive_and_wait();
|
||||
smem.gemm1WarpGrpBar.arrive_and_wait();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if SWAP_AB
|
||||
// @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead.
|
||||
rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf],
|
||||
smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar);
|
||||
// @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead.
|
||||
rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf],
|
||||
smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar);
|
||||
#else
|
||||
rescaleGemm1AccForNewRowMax_sync(
|
||||
warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
|
||||
rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf],
|
||||
smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
|
||||
#endif
|
||||
auto& xBuf = smem.xBuf(idxXBuf);
|
||||
auto& xBuf = smem.xBuf(idxXBuf);
|
||||
|
||||
auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8,
|
||||
gmma::getSwizzleMode<true>(SharedMem::XBuffer::Elem{}))
|
||||
.raw();
|
||||
auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8,
|
||||
gmma::getSwizzleMode<true>(SharedMem::XBuffer::Elem{}))
|
||||
.raw();
|
||||
#if CACHE_ELEM_ENUM == 0
|
||||
auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8,
|
||||
gmma::getSwizzleMode<true>(SharedMem::VBuffer::Elem{}))
|
||||
.raw();
|
||||
auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8,
|
||||
gmma::getSwizzleMode<true>(SharedMem::VBuffer::Elem{}))
|
||||
.raw();
|
||||
#endif
|
||||
#if SWAP_AB
|
||||
//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed.
|
||||
#pragma unroll
|
||||
for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++)
|
||||
{
|
||||
for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++)
|
||||
{
|
||||
#if CACHE_ELEM_ENUM == 2
|
||||
Vec<RegMatAFrag, gemm1NbGmmaInstM> const fragA
|
||||
= loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK);
|
||||
Vec<RegMatAFrag, gemm1NbGmmaInstM> const fragA
|
||||
= loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK);
|
||||
#if !defined(NDEBUG) && DBG_PRINT
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
printf("fragA:\nidxInstK == %u\n", idxInstK);
|
||||
}
|
||||
smem.gemm1WarpGrpBar.arrive_and_wait();
|
||||
for (int m = 0; m < 2; m++)
|
||||
{
|
||||
for (int w = 0; w < 4; w++)
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
if (warpRank == w)
|
||||
printf("fragA:\nidxInstK == %u\n", idxInstK);
|
||||
}
|
||||
smem.gemm1WarpGrpBar.arrive_and_wait();
|
||||
for (int m = 0; m < 2; m++)
|
||||
{
|
||||
for (int w = 0; w < 4; w++)
|
||||
{
|
||||
if (laneId() == 0)
|
||||
if (warpRank == w)
|
||||
{
|
||||
printf(" warpRank = %u\n", warpRank);
|
||||
}
|
||||
__syncwarp();
|
||||
for (int a = 0; a < 2; a++)
|
||||
{
|
||||
for (int b = 0; b < 8; b++)
|
||||
if (laneId() == 0)
|
||||
{
|
||||
for (int c = 0; c < 2; c++)
|
||||
printf(" warpRank = %u\n", warpRank);
|
||||
}
|
||||
__syncwarp();
|
||||
for (int a = 0; a < 2; a++)
|
||||
{
|
||||
for (int b = 0; b < 8; b++)
|
||||
{
|
||||
for (int d = 0; d < 4; d++)
|
||||
for (int c = 0; c < 2; c++)
|
||||
{
|
||||
if (laneId() == b * 4 + d)
|
||||
for (int d = 0; d < 4; d++)
|
||||
{
|
||||
for (int e = 0; e < 4; e++)
|
||||
if (laneId() == b * 4 + d)
|
||||
{
|
||||
auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
|
||||
fragA[m](0, c)(a, 0));
|
||||
printf("%.2f, ", float(elem4[e]));
|
||||
for (int e = 0; e < 4; e++)
|
||||
{
|
||||
auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
|
||||
fragA[m](0, c)(a, 0));
|
||||
printf("%.2f, ", float(elem4[e]));
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
if (laneId() == 0)
|
||||
{
|
||||
printf("\n");
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
if (laneId() == 0)
|
||||
if (laneId() == 0 && a == 0)
|
||||
{
|
||||
printf("\n");
|
||||
printf("----------------------\n");
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
if (laneId() == 0 && a == 0)
|
||||
{
|
||||
printf("----------------------\n");
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
smem.gemm1WarpGrpBar.arrive_and_wait();
|
||||
}
|
||||
smem.gemm1WarpGrpBar.arrive_and_wait();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * idxInstK};
|
||||
auto const descX = addAddr(descXBase,
|
||||
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
|
||||
0, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
|
||||
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * idxInstK};
|
||||
auto const descX = addAddr(descXBase,
|
||||
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
|
||||
0, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
|
||||
#if CACHE_ELEM_ENUM == 2
|
||||
gmma::fence();
|
||||
gmma::fence();
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++)
|
||||
{
|
||||
for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++)
|
||||
{
|
||||
#if CACHE_ELEM_ENUM == 0
|
||||
auto const descV
|
||||
= addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0));
|
||||
gmma::mma_async_shmA<MathElem, ctaNbQHeads, true, false>(
|
||||
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
|
||||
descV, descX, true);
|
||||
auto const descV
|
||||
= addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0));
|
||||
gmma::mma_async_shmA<MathElem, ctaNbQHeads, true, false>(
|
||||
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
|
||||
descV, descX, true);
|
||||
#elif CACHE_ELEM_ENUM == 2
|
||||
gmma::mma_async_regA<MathElem, ctaNbQHeads>(
|
||||
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
|
||||
reinterpret_cast<uint32_t const(&)[2][2][1]>(fragA[idxInstM]), descX, true);
|
||||
gmma::mma_async_regA<MathElem, ctaNbQHeads>(
|
||||
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
|
||||
reinterpret_cast<uint32_t const(&)[2][2][1]>(fragA[idxInstM]), descX, true);
|
||||
#endif
|
||||
}
|
||||
gmma::commit_group();
|
||||
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
|
||||
// gmma.
|
||||
gmma::wait_group<0>();
|
||||
}
|
||||
gmma::commit_group();
|
||||
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
|
||||
// gmma.
|
||||
gmma::wait_group<0>();
|
||||
}
|
||||
#else
|
||||
auto const descVTBase = gmma::makeMatDesc(
|
||||
nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode<true>(SharedMem::VTBuffer{}))
|
||||
.raw();
|
||||
vtBar.produced.arrive_and_wait();
|
||||
auto const descVTBase = gmma::makeMatDesc(
|
||||
nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode<true>(SharedMem::VTBuffer{}))
|
||||
.raw();
|
||||
vtBar.produced.arrive_and_wait();
|
||||
// if (idxIter == 1 && threadIdx.x == 0) {
|
||||
// printf("vtBuf:\n");
|
||||
// dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf);
|
||||
// }
|
||||
#pragma unroll
|
||||
for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
|
||||
for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
|
||||
{
|
||||
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
|
||||
auto const descX = addAddr(descXBase,
|
||||
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
|
||||
gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
|
||||
auto const descVT = addAddr(
|
||||
descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
|
||||
gmma::mma_async_shmA<MathElem, headElems>(
|
||||
reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
|
||||
descVT, true);
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
|
||||
{
|
||||
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
|
||||
auto const descX = addAddr(descXBase,
|
||||
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
|
||||
gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
|
||||
auto const descVT = addAddr(
|
||||
descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
|
||||
gmma::mma_async_shmA<MathElem, headElems>(
|
||||
reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
|
||||
descVT, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
gmma::commit_group();
|
||||
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of gmma.
|
||||
gmma::wait_group<0>();
|
||||
gmma::commit_group();
|
||||
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
|
||||
// gmma.
|
||||
gmma::wait_group<0>();
|
||||
#endif
|
||||
}
|
||||
|
||||
if (idxIter == nbIters - 1)
|
||||
{
|
||||
// gmma::wait_group should have already synchronized threads, so this may be unnecessary.
|
||||
@ -1471,8 +1578,24 @@ CUBIN_EXPORT __global__
|
||||
tensorMap
|
||||
#endif
|
||||
};
|
||||
#if SKIP_SOFTMAX_ATTN
for (auto& b : smem.skipSoftmaxXBar)
{
unused(b.consumed.arrive());
}
#endif
for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++)
{
uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
auto& vBar = smem.vBar[idxVBuf];
#if SKIP_SOFTMAX_ATTN
uint32_t idxXBuf = idxIter % SharedMem::nbXBuf;
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
skipSoftmaxXBar.produced.arrive_and_wait();
bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToV[idxXBuf];
skipSoftmaxXBar.consumed.arrive();
#endif

uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq;
vTileLoader.loadPages(idxVTile);
#if USE_INPUT_KV || ENABLE_PDL == 2
@ -1506,8 +1629,20 @@ CUBIN_EXPORT __global__
}
#endif

uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
auto& vBar = smem.vBar[idxVBuf];
#if SKIP_SOFTMAX_ATTN
if (shouldSkipSoftmaxAttn)
{
vBar.consumed.arrive_and_wait();
// Compared to the non-skip path, we need one extra arrive on vBar.produced to avoid a race where
// vBar.consumed is arrived again before it has been waited on. Without skip-softmax, the XVGemm
// warpgroup waits on the tx_count, so it can never run ahead of the V-load warp. With skip-softmax,
// the tx_count for a skipped tile is 0 and the previous vBar only counts the XVGemm warpgroup
// threads, so the XVGemm warpgroup may run ahead of the V-load warp and arrive vBar.consumed before
// this warp has done arrive_and_wait on it.
vBar.produced.arrive();
continue;
}
#endif

vBar.consumed.arrive_and_wait();
if (warpElectSync())
{
@ -1517,6 +1652,9 @@ CUBIN_EXPORT __global__
vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced);
}
}
#if SKIP_SOFTMAX_ATTN
vBar.produced.arrive();
#endif
__syncwarp();
}
}
@ -1992,9 +2130,23 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
#endif // SPEC_DEC

// smemColMax is persistent across multiple iterations
#if SKIP_SOFTMAX_ATTN
__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax,
Gemm0Acc const& src, float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip)
#else
__device__ inline RegColWiseVec computeWarpGrpColMax_sync(
CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src)
#endif
{
#if SKIP_SOFTMAX_ATTN
if (threadIdx.x == 0)
{
*smemSkipVote = maybeSkip ? 1U : 0U; // will sync before vote
}
float const lnThreshold
= log(skipSoftmaxThreshold); // this can be -inf, but should be safe as we only use it for comparison
#endif

auto colMax = RegColWiseVec::filled(Vec<float, 2>::filled(safeInitRowMax));
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
@ -2029,6 +2181,9 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
}

uint32_t const lane = laneId();
#if SKIP_SOFTMAX_ATTN
auto prevOrCurrentMax = RegColWiseVec();
#if SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
if (lane < 4)
{
#pragma unroll
@ -2037,12 +2192,43 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
prevOrCurrentMax[n][j] = smemColMax[8 * n + 2 * lane + j];
}
}
}
warpGrpBar.arrive_and_wait();
#endif
#endif

if (lane < 4)
{
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
#if SKIP_SOFTMAX_ATTN && !SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
// prevOrCurrentMax is <= the final smemColMax (after the updates from all 4 warps are done), but it
// is always >= smemColMax(Prev), the smemColMax value *before* this tile was computed.
// Using prevOrCurrentMax for the skip decision is safe: 1) if every warp's local max is <
// smemColMax(Prev), then prevOrCurrentMax == smemColMax(Prev) and the result is unaffected; 2) if
// some local max is > smemColMax(Prev), then prevOrCurrentMax > smemColMax(Prev) and some warps may
// incorrectly vote to skip, but at least one warp, the one whose local col max is larger, will not
// skip, so the tile is not skipped.
// This saves some synchronization and checking, but breaks when the threshold is > 1.
prevOrCurrentMax[n][j] = atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
#else
atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
#endif
}
}
}
warpGrpBar.arrive_and_wait();

uint32_t const idxInQuad = lane % 4;
#if SKIP_SOFTMAX_ATTN
bool localShouldSkip = true;
#endif

#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
@ -2050,10 +2236,21 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
{
#if SKIP_SOFTMAX_ATTN
if (lane < 4 && 8 * n + 2 * idxInQuad + j < headGrpSize)
{
localShouldSkip &= (colMax[n][j] - prevOrCurrentMax[n][j]) < lnThreshold;
}
#endif
assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]);
colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j];
}
}

#if SKIP_SOFTMAX_ATTN
atomicAnd(smemSkipVote, static_cast<uint32_t>(localShouldSkip)); // this will be translated to redux and voteu
#endif

warpGrpBar.arrive_and_wait();
return colMax;
}
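A short derivation of the comparison used above (restating the code, not adding new behavior): the skip test elsewhere is expf(localMax - prevMax) < threshold, and since exp is monotonic this is equivalent to localMax - prevMax < logf(threshold). Precomputing lnThreshold = log(skipSoftmaxThreshold) therefore turns the per-element test into a single subtract-and-compare; when the threshold is 0 (skipping disabled), lnThreshold is -inf and the comparison is simply never true.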
@ -2199,7 +2396,7 @@ __device__ inline void storeGemm0AccToShm(
|
||||
uint32_t const idxOctInsideHalf = idxInHalf / 8;
|
||||
uint32_t const idxRowInsideOct = lane % 8;
|
||||
uint32_t const warpBaseC = 16 * warpRank;
|
||||
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair<uint32_t, uint32_t>
|
||||
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> mha::pair<uint32_t, uint32_t>
|
||||
{
|
||||
uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols;
|
||||
uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols;
|
||||
@ -3231,6 +3428,24 @@ __device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst,
|
||||
}
|
||||
|
||||
#ifndef GENERATE_CUBIN
|
||||
uint32_t computeNbSubSeqPerSeqHopperF8MHA(
|
||||
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
|
||||
{
|
||||
auto const env = std::getenv("XQA_NB_SUB_SEQ");
|
||||
if (env != nullptr)
|
||||
{
|
||||
int32_t const val = std::stoi(env);
|
||||
if (val > 0)
|
||||
{
|
||||
return val;
|
||||
}
|
||||
}
|
||||
float const factor = 0.25f;
|
||||
return mha::min<uint32_t>(
|
||||
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
|
||||
divUp(maxSeqLen, gemm0CtaTileNbTokens));
|
||||
}
|
||||
|
||||
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#if SLIDING_WINDOW
|
||||
uint32_t slidingWinSize,
|
||||
@ -3268,6 +3483,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
// int8/fp8 KV cache.
|
||||
#if SPEC_DEC
|
||||
SpecDecParams const& specDecParams,
|
||||
#endif
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
float const skipSoftmaxThresholdScaleFactor,
|
||||
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
|
||||
#endif
|
||||
#endif
|
||||
uint32_t* semaphores, void* scratch, cudaStream_t stream)
|
||||
{
|
||||
@ -3286,22 +3507,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
uint32_t const nbVHeads = nbKHeads;
|
||||
uint32_t const nbQHeads = nbKHeads * headGrpSize;
|
||||
uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
|
||||
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
|
||||
{
|
||||
auto const env = std::getenv("XQA_NB_SUB_SEQ");
|
||||
if (env != nullptr)
|
||||
{
|
||||
int32_t const val = std::stoi(env);
|
||||
if (val > 0)
|
||||
{
|
||||
return val;
|
||||
}
|
||||
}
|
||||
float const factor = 0.25f;
|
||||
return mha::min<uint32_t>(
|
||||
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
|
||||
divUp(maxSeqLen, gemm0CtaTileNbTokens));
|
||||
}();
|
||||
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqHopperF8MHA(prop, batchSize, nbKHeads, maxSeqLen);
|
||||
#if SPEC_DEC
|
||||
uint32_t const qSeqLen = specDecParams.qSeqLen;
|
||||
#else
|
||||
@ -3371,6 +3577,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
|
||||
#endif
|
||||
#if SPEC_DEC
|
||||
specDecParams,
|
||||
#endif
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
skipSoftmaxThresholdScaleFactor,
|
||||
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
skippedBlockCount, totalBlockCount,
|
||||
#endif
|
||||
#endif
|
||||
semaphores, scratch);
|
||||
#else
|
||||
|
||||
@ -1272,6 +1272,19 @@ using is_void = is_same<remove_cv_t<T>, void>;
|
||||
template <typename T>
|
||||
inline constexpr bool is_void_v = is_void<T>::value;
|
||||
#endif
|
||||
|
||||
#ifndef GENERATE_CUBIN
|
||||
template <typename T1, typename T2>
|
||||
using pair = std::pair<T1, T2>;
|
||||
#else
|
||||
template <typename T1, typename T2>
|
||||
struct pair
|
||||
{
|
||||
T1 first;
|
||||
T2 second;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace mha
|
||||
|
||||
#if GENERATE_CUBIN
|
||||
|
||||
@ -50,7 +50,8 @@ using Vector = Matrix<Type, Size, 1>;
|
||||
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
|
||||
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
|
||||
{
|
||||
uint32_t const nbTiles = divUp(seqLen, tileSize);
|
||||
auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();
|
||||
@ -61,6 +62,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
|
||||
float const qkScale = qScale * kvScale / sqrtf(validElemsPerHead);
|
||||
uint32_t const seqBeg = (seqLen < slidingWinSize ? 0 : seqLen - slidingWinSize);
|
||||
uint32_t const idxTileBeg = seqBeg / tileSize;
|
||||
|
||||
uint32_t const nbSubSeq = (multiBlockNum > 0 && nbTiles >= 2) ? mha::min(nbTiles, multiBlockNum) : 1;
|
||||
std::vector<Eigen::Vector<float, headGrpSize>> skipRowMaxs(nbSubSeq);
|
||||
for (uint32_t i = 0; i < nbSubSeq; i++)
|
||||
{
|
||||
skipRowMaxs[i].fill(-INFINITY);
|
||||
}
|
||||
bool const disableSkipForShortSeq = (seqLen < skipSoftmaxThresholdScaleFactor);
|
||||
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / seqLen;
|
||||
|
||||
for (uint32_t idxTile = idxTileBeg; idxTile < nbTiles; idxTile++)
|
||||
{
|
||||
Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> gemm0Acc;
|
||||
@ -88,7 +99,22 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
|
||||
}
|
||||
}
|
||||
|
||||
Eigen::Vector<float, headGrpSize> const tileRowMax = gemm0Acc.rowwise().maxCoeff().cwiseMax(rowMax).eval();
|
||||
Eigen::Vector<float, headGrpSize> const localRowMax = gemm0Acc.rowwise().maxCoeff().eval();
|
||||
Eigen::Vector<float, headGrpSize> const tileRowMax = localRowMax.cwiseMax(rowMax).eval();
|
||||
auto const prevSkipRowMax = skipRowMaxs[idxTile % nbSubSeq];
|
||||
skipRowMaxs[idxTile % nbSubSeq] = localRowMax.cwiseMax(skipRowMaxs[idxTile % nbSubSeq]).eval();
|
||||
|
||||
if (!disableSkipForShortSeq && skipSoftmaxThreshold > 0)
|
||||
{
|
||||
*totalBlockCount += 1;
|
||||
auto const skipSoftmaxMask = ((localRowMax - prevSkipRowMax).array() < std::log(skipSoftmaxThreshold));
|
||||
bool const skipBlock = skipSoftmaxMask.all() && ((idxTile - idxTileBeg) >= nbSubSeq);
|
||||
if (skipBlock)
|
||||
{
|
||||
*skippedBlockCount += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> tileX
|
||||
= (gemm0Acc.colwise() - tileRowMax).array().exp().eval();
|
||||
@ -138,7 +164,8 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
|
||||
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
|
||||
refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \
|
||||
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
|
||||
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, \
|
||||
float skipSoftmaxThreshold, uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
|
||||
|
||||
INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
|
||||
INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);
|
||||
|
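The reference check above decides whether an entire K/V tile can be skipped by comparing the tile's local row max against the running max of its sub-sequence; the threshold itself is the scale factor divided by the sequence length, and skipping is disabled for short sequences. A scalar sketch of that decision, collapsing to a single value where the real code keeps one running max per head row:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        float const scaleFactor = 500.f;   // skipSoftmaxThresholdScaleFactor (illustrative)
        uint32_t const seqLen = 4096;
        // Skipping is disabled when the sequence is shorter than the scale factor.
        float const threshold = (seqLen < scaleFactor) ? 0.f : scaleFactor / seqLen;
        float const prevRunningMax = 2.1f; // running max of earlier tiles in this sub-sequence
        float const localRowMax = -3.0f;   // max logit of the current tile
        bool const skip = threshold > 0.f && (localRowMax - prevRunningMax) < std::log(threshold);
        std::printf("threshold=%g skipBlock=%d\n", threshold, skip);
    }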
||||
@ -88,7 +88,8 @@ struct CacheSeq<true, true>
|
||||
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
|
||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
|
||||
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum);
|
||||
|
||||
template <typename MathElem, bool isPaged, bool useBeamSearch>
|
||||
#if SPEC_DEC
|
||||
|
||||
@ -150,7 +150,8 @@ template <uint32_t nbKHeads>
|
||||
#endif
|
||||
#endif
|
||||
void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false,
|
||||
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
|
||||
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30,
|
||||
float skipSoftmaxThresholdScaleFactor = 0.0f)
|
||||
{
|
||||
#if IS_MLA
|
||||
if (nbKHeads != 1)
|
||||
@ -224,6 +225,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
seqLen = (16U << 20) / gmemCacheHeadBytes; // 32MB per K+V head.
|
||||
}
|
||||
ctxLen = std::min(ctxLen, seqLen);
|
||||
uint32_t skippedBlockCount = 0;
|
||||
uint32_t totalBlockCount = 0;
|
||||
if (skipSoftmaxThresholdScaleFactor > 0)
|
||||
{
|
||||
assert(useQGMMA);
|
||||
}
|
||||
float const kScale = cacheElemSize == 2 ? 1.f : 1 / 4.f;
|
||||
float const vScale = kScale;
|
||||
float const qScale = 1.f;
|
||||
@ -329,6 +336,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
auto const rcpOutScale = ManagedMemBuf<float>(1);
|
||||
auto const seqLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
|
||||
auto const ctxLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
#ifdef SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
auto const kernelSkippedBlockCount = ManagedMemBuf<uint32_t>(1);
|
||||
auto const kernelTotalBlockCount = ManagedMemBuf<uint32_t>(1);
|
||||
kernelSkippedBlockCount[0] = 0;
|
||||
kernelTotalBlockCount[0] = 0;
|
||||
#endif
|
||||
#else
|
||||
EXPECT_EQ(skipSoftmaxThresholdScaleFactor, 0.0f)
|
||||
<< "Got non-zero skipSoftmaxThresholdScaleFactor while SKIP_SOFTMAX_ATTN is not enabled.";
|
||||
#endif
|
||||
#if USE_PAGED_KV_CACHE
|
||||
auto const pageListBuf = ManagedMemBuf<std::byte>(pageListBytes);
|
||||
#if PAGED_KV_CACHE_LAYOUT == 1
|
||||
@ -726,6 +744,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
maxSeqLen, &seqLenList[0][0], batchSize, kvCacheScale.get(), semaphores.get(), scratch, stream);
|
||||
};
|
||||
#else
|
||||
auto multiBlockNum = [&]()
|
||||
{
|
||||
auto const calcFunc = useQGMMA ? &computeNbSubSeqPerSeqHopperF8MHA : &computeNbSubSeqPerSeqMHA;
|
||||
return calcFunc(prop, batchSize, nbKHeads, maxSeqLen);
|
||||
}();
|
||||
auto runKernel = [&]()
|
||||
{
|
||||
auto const launchFunc = useQGMMA ? &launchHopperF8MHA : &launchMHA;
|
||||
@ -776,6 +799,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
batchSize, kvCacheScale.get(),
|
||||
#if SPEC_DEC
|
||||
specDecParams,
|
||||
#endif
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
skipSoftmaxThresholdScaleFactor,
|
||||
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
kernelSkippedBlockCount.get(), kernelTotalBlockCount.get(),
|
||||
#endif
|
||||
#endif
|
||||
semaphores.get(), scratch, stream);
|
||||
checkCuda(cudaGetLastError());
|
||||
@ -813,6 +842,10 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
checkCuda(cudaEventRecord(toc, stream));
|
||||
prefetchToDevice(cudaCpuDeviceId);
|
||||
checkCuda(cudaStreamSynchronize(stream));
|
||||
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
kernelSkippedBlockCount[0] /= nbIters;
|
||||
kernelTotalBlockCount[0] /= nbIters;
|
||||
#endif
|
||||
if (testPerf)
|
||||
{
|
||||
float ms;
|
||||
@ -849,6 +882,15 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
= totalNbCacheLoadBytes + inputBytes + outputBytes; // we ignore page indices and beam search indices.
|
||||
float const dramSolTime = totalTraffic / bandwidth * 1E3f;
|
||||
float const dramSolRatio = dramSolTime / ms;
|
||||
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes
|
||||
* (nbKHeads + nbVHeads * (1 - 1.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]))
|
||||
* nbLoadedCacheTokens;
|
||||
float const totalTrafficWithSkip
|
||||
= totalNbCacheLoadWithSkip + inputBytes + outputBytes; // we ignore page indices and beam search indices.
|
||||
float const dramSolTimeWithSkip = totalTrafficWithSkip / bandwidth * 1E3f;
|
||||
float const dramSolRatioWithSkip = dramSolTimeWithSkip / ms;
|
||||
#endif
|
||||
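The adjusted speed-of-light estimate above only discounts V-cache traffic for skipped blocks, since K is always read to compute the row maxima. A back-of-the-envelope version of that adjustment with purely illustrative numbers:

    #include <cstdio>

    int main() {
        double const gmemCacheHeadBytes = 576.0;   // assumed bytes per cache head
        double const nbKHeads = 8, nbVHeads = 8, nbLoadedCacheTokens = 1e6;
        double const skippedRatio = 0.4;           // skippedBlockCount / totalBlockCount
        double const bytesNoSkip = gmemCacheHeadBytes * (nbKHeads + nbVHeads) * nbLoadedCacheTokens;
        double const bytesWithSkip
            = gmemCacheHeadBytes * (nbKHeads + nbVHeads * (1.0 - skippedRatio)) * nbLoadedCacheTokens;
        std::printf("KV traffic saved: %.1f%%\n", 100.0 * (1.0 - bytesWithSkip / bytesNoSkip));
    }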
if (verbose)
|
||||
{
|
||||
printf("done\n");
|
||||
@ -863,7 +905,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
}
|
||||
float const tops = headGrpSize * qSeqLen * float(seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
|
||||
* nbKHeads * batchSize / (ms * 1E-3F) * 1E-12F;
|
||||
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
|
||||
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
|
||||
printf("dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n", dramSolRatioWithSkip * 100, ms, tops);
|
||||
#else
|
||||
printf("dramSolRatio: %f%% (%f ms, TOPS = %f)\n", dramSolRatio * 100, ms, tops);
|
||||
#endif
|
||||
}
|
||||
if (refCheck)
|
||||
{
|
||||
@ -1084,8 +1132,8 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
if (useQGMMA)
|
||||
{
|
||||
refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
|
||||
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize,
|
||||
refAttentionSinks);
|
||||
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks,
|
||||
skipSoftmaxThresholdScaleFactor, &skippedBlockCount, &totalBlockCount, multiBlockNum);
|
||||
// refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
|
||||
// vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
|
||||
}
|
||||
@ -1132,6 +1180,14 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
printf("host skippedBlockCount: %d/%d (%.2f%%)\n", skippedBlockCount, totalBlockCount,
|
||||
totalBlockCount == 0 ? 0.0f : 100.0f * skippedBlockCount / totalBlockCount);
|
||||
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
|
||||
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
|
||||
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
|
||||
#endif
|
||||
#endif
|
||||
if (saveData)
|
||||
{
|
||||
fout_refOutput.close();
|
||||
@ -1253,6 +1309,14 @@ TEST(RefCheck, llama_V2_70b)
|
||||
#if SLIDING_WINDOW
|
||||
runTest<2>(2, 4096, false, true, false, false, false, ~0, 256);
|
||||
runTest<2>(2, 400, false, true, false, false, false, ~0U, 256);
|
||||
#endif
|
||||
#if SKIP_SOFTMAX_ATTN
|
||||
runTest<1>(32, 2048, false, true, false, false, false, ~0U, 1U << 30, 0.f);
|
||||
runTest<4>(32, 1538, false, true, false, false, false, ~0U, 1U << 30, 1280.f);
|
||||
runTest<2>(32, 4096, false, true, false, false, false, ~0U, 1U << 30, 125.f);
|
||||
runTest<4>(32, 300, false, true, false, false, false, ~0U, 1U << 30, 80.f);
|
||||
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 501.0f);
|
||||
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 500.f);
|
||||
#endif
|
||||
runTest<8>(120, 367, false, true);
|
||||
runTest<8>(1792, 2048, false, true);
|
||||
|
||||
@@ -157,6 +157,11 @@ set(UCX_WRAPPER_TARGET tensorrt_llm_ucx_wrapper)

if(NIXL_ROOT)
  set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
  set(TRANSFER_AGENT_BINDING_TARGET tensorrt_llm_transfer_agent_binding)
endif()

if(MOONCAKE_ROOT)
  set(MOONCAKE_WRAPPER_TARGET tensorrt_llm_mooncake_wrapper)
endif()

add_subdirectory(executor)
@@ -272,6 +277,11 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
  add_dependencies(${SHARED_TARGET} ${NIXL_WRAPPER_TARGET})
endif()

if(TARGET ${MOONCAKE_WRAPPER_TARGET})
  target_link_libraries(${MOONCAKE_WRAPPER_TARGET} INTERFACE ${SHARED_TARGET})
  add_dependencies(${SHARED_TARGET} ${MOONCAKE_WRAPPER_TARGET})
endif()

if(NOT WIN32)
  # Load libraries at $PREFIX/lib from
  # $PREFIX/lib/python3.12/site-packages/tensorrt_llm/libs

@@ -154,7 +154,8 @@ bool CacheFormatter::needSendCache(
        return true;
    }

    int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
    int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
    int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
    int selfTpRankInDpGroup = selfTpRank;
    if (selfConfig.getParallelConfig().mEnableAttentionDP)
    {

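The updated mapping above derives the tensor-parallel rank from a rank index that also enumerates context-parallel ranks. A worked example, assuming CP ranks vary fastest within a TP group (not stated explicitly in this hunk):

    #include <cstdio>

    int main() {
        int const tp = 4, cp = 2;
        for (int selfIdx = 0; selfIdx < tp * cp; ++selfIdx) {
            // selfIdx 0..7 -> tpRank 0,0,1,1,2,2,3,3
            int const tpRank = (selfIdx % (tp * cp)) / cp;
            std::printf("selfIdx=%d -> tpRank=%d\n", selfIdx, tpRank);
        }
    }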
||||
@ -81,6 +81,11 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransc
|
||||
backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
|
||||
TLLM_LOG_INFO("Enable NIXL KV cache transport.");
|
||||
}
|
||||
else if (common::getEnvUseMooncakeKvCache())
|
||||
{
|
||||
backendType = executor::CacheTransceiverConfig::BackendType::MOONCAKE;
|
||||
TLLM_LOG_INFO("Enable MOONCAKE KV cache transport.");
|
||||
}
|
||||
else if (common::getEnvUseMPIKvCache())
|
||||
{
|
||||
backendType = executor::CacheTransceiverConfig::BackendType::MPI;
|
||||
@ -203,9 +208,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
|
||||
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
|
||||
{
|
||||
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
|
||||
mCacheTransBufferManagerPtrs, *mCacheState);
|
||||
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
|
||||
TLLM_LOG_INFO("NIXL Connection Manager created");
|
||||
}
|
||||
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
|
||||
{
|
||||
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
|
||||
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
|
||||
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
|
||||
}
|
||||
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
|
||||
{
|
||||
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
|
||||
|
||||
@ -358,8 +358,9 @@ public:
|
||||
|
||||
TransceiverTag::Id id;
|
||||
RequestInfo info;
|
||||
auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info)
|
||||
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
|
||||
auto const* connection = isAgent
|
||||
? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate)
|
||||
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id));
|
||||
if (connection == nullptr && !mManager->isRunning())
|
||||
{
|
||||
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
|
||||
@ -395,8 +396,8 @@ public:
|
||||
if (it == mRequestToSession.end())
|
||||
{
|
||||
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
|
||||
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager,
|
||||
info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
|
||||
DataContext{tagFromRequestId(requestId), mTerminate}, mSelfState, info.getTransState(),
|
||||
mBufferManager, info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
|
||||
!common::getEnvKVCacheTimeOutputPath().empty());
|
||||
session.setTime(TransferSession::kTimeRequestInfo);
|
||||
it = mRequestToSession.emplace(requestId, std::move(session)).first;
|
||||
@ -685,6 +686,10 @@ private:
|
||||
{
|
||||
future.get();
|
||||
}
|
||||
if (mResponseFuture.valid())
|
||||
{
|
||||
mResponseFuture.get();
|
||||
}
|
||||
}
|
||||
|
||||
void removeResponse(std::map<RequestIdType, Response>::iterator it)
|
||||
@ -886,9 +891,9 @@ public:
|
||||
}
|
||||
}
|
||||
auto const& resource = getReceiveCacheResource(llmRequest);
|
||||
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
|
||||
contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(),
|
||||
&llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
|
||||
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId), mTerminate},
|
||||
mSelfState, contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(),
|
||||
requestInfo.getLastBlockKey(), &llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
|
||||
}
|
||||
|
||||
std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
|
||||
@ -964,7 +969,7 @@ public:
|
||||
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
|
||||
TLLM_CHECK(agentConnection);
|
||||
isReady = agentConnection->recvReadySignal(
|
||||
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
|
||||
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG, mTerminate});
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -979,6 +984,7 @@ public:
|
||||
|
||||
~Impl()
|
||||
{
|
||||
mTerminate.store(true);
|
||||
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
|
||||
{
|
||||
asyncResource->mTerminate = true;
|
||||
@ -1134,6 +1140,7 @@ private:
|
||||
runtime::BufferManager mBufferManager;
|
||||
std::ofstream mMeasuresFile;
|
||||
std::mutex mMeasuresFileMutex;
|
||||
std::atomic<bool> mTerminate{false};
|
||||
};
|
||||
|
||||
void CacheSender::ImplDeleter::operator()(Impl* ptr)
|
||||
|
||||
@ -1224,7 +1224,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
|
||||
auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr && blockItr != blockKeys.end()
|
||||
? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse)
|
||||
: std::make_tuple(false, 0, nullptr);
|
||||
if (matchingBlock != nullptr)
|
||||
if (matchingBlock != nullptr && numMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen())
|
||||
{
|
||||
KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId();
|
||||
|
||||
@ -1338,6 +1338,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
|
||||
}
|
||||
}
|
||||
|
||||
sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens);
|
||||
return numMatchedTokens;
|
||||
}
|
||||
|
||||
@ -1555,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
|
||||
std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
|
||||
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
|
||||
{
|
||||
SizeType32 numBlocksStoredForReuse = 0;
|
||||
@ -1568,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
|
||||
|
||||
auto numBlocks = blockKeys.size();
|
||||
std::vector<BlockPtr> storedBlocks;
|
||||
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
|
||||
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
|
||||
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
|
||||
{
|
||||
auto const bid = blockIds[blockCnt];
|
||||
@ -1619,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
|
||||
if (pinBlocks)
|
||||
{
|
||||
searchRoot->incRefCount();
|
||||
pinnedBlockIds.push_back(searchRoot->getBlockId());
|
||||
}
|
||||
lastStoredId = searchRoot->getBlockId();
|
||||
}
|
||||
if (mEventManager)
|
||||
{
|
||||
mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
|
||||
}
|
||||
return {numBlocksStoredForReuse, lastStoredId};
|
||||
return {numBlocksStoredForReuse, pinnedBlockIds};
|
||||
}
|
||||
|
||||
void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
|
||||
@ -1714,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
|
||||
return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
|
||||
}
|
||||
|
||||
std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
|
||||
std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
|
||||
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
|
||||
{
|
||||
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
|
||||
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
|
||||
for (auto& [_, manager] : mWindowBlockManagers)
|
||||
{
|
||||
lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
|
||||
pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
|
||||
}
|
||||
return lastStoredId;
|
||||
return pinnedBlockIds;
|
||||
}
|
||||
|
||||
std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
|
||||
@ -1731,9 +1732,22 @@ std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
|
||||
// Released block will be stored when reuse is enabled.
|
||||
// Reuse is implied to be enabled if llmRequest is provided.
|
||||
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
|
||||
|
||||
// For now, the attention kernel only accepts a single
|
||||
// "prepopulatedPromptLen", that is, all window sizes will use the same
|
||||
// prepopulated prompt length, so it is meaningless right now to save
|
||||
// blocks only for a certain window size while blocks in the other
|
||||
// window size are not valid for saving for reuse.
|
||||
bool isAllWindowSizesValidForStoreForReuse = true;
|
||||
for (auto& [windowSize, manager] : mWindowBlockManagers)
|
||||
{
|
||||
isAllWindowSizesValidForStoreForReuse &= manager.isSequenceValidForStoreForReuse(sequence.getRequestId());
|
||||
}
|
||||
|
||||
for (auto& [_, manager] : mWindowBlockManagers)
|
||||
{
|
||||
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1)
|
||||
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1
|
||||
|| !isAllWindowSizesValidForStoreForReuse)
|
||||
{
|
||||
lastStoredId = manager.releaseBlocks(sequence, std::nullopt);
|
||||
}
|
||||
@ -1753,7 +1767,7 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
|
||||
}
|
||||
}
|
||||
|
||||
void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
|
||||
void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
|
||||
{
|
||||
// Use the first window size
|
||||
if (mWindowBlockManagers.empty())
|
||||
@ -1761,7 +1775,7 @@ void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
|
||||
return;
|
||||
}
|
||||
auto& firstManager = mWindowBlockManagers.begin()->second;
|
||||
firstManager.unpinBlocksById(blockId);
|
||||
firstManager.unpinBlocksById(blockIds);
|
||||
}
|
||||
|
||||
void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
|
||||
@ -1774,21 +1788,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
|
||||
}
|
||||
}
|
||||
|
||||
void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
|
||||
void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
|
||||
{
|
||||
if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
|
||||
if (blockIds.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto block = mAllBlocksById[blockId];
|
||||
while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
|
||||
|
||||
for (auto const& blockId : blockIds)
|
||||
{
|
||||
block->decRefCount();
|
||||
if (!block->hasRefs())
|
||||
TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
|
||||
"Block id %d is out of range", blockId);
|
||||
auto block = mAllBlocksById[blockId];
|
||||
if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
|
||||
{
|
||||
mEvictionPolicy->releaseBlock(block);
|
||||
block->decRefCount();
|
||||
if (!block->hasRefs())
|
||||
{
|
||||
mEvictionPolicy->releaseBlock(block);
|
||||
}
|
||||
}
|
||||
block = std::move(block->getPrevBlock());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1856,7 +1875,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
|
||||
(void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
|
||||
}
|
||||
|
||||
std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
|
||||
std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
|
||||
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
|
||||
{
|
||||
auto constexpr beamIdx = 0;
|
||||
@ -1869,7 +1888,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
|
||||
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
|
||||
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
|
||||
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
|
||||
return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
|
||||
|
||||
auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
|
||||
|
||||
return pinnedBlockIds;
|
||||
}
|
||||
|
||||
std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
|
||||
@ -1908,7 +1930,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
|
||||
std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
|
||||
[](BlockPtr const& block) { return block->getBlockId(); });
|
||||
|
||||
auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
|
||||
auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
|
||||
TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
|
||||
sequence.getRequestId(), numBlocksStoredForReuse);
|
||||
}
|
||||
@ -2485,15 +2507,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
|
||||
return lastStoredId;
|
||||
}
|
||||
|
||||
std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
|
||||
std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
|
||||
RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
|
||||
{
|
||||
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
|
||||
auto& sequence = getSequence(requestId);
|
||||
std::optional<KVCacheBlock::IdType> lastStoredId
|
||||
= mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
|
||||
auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
|
||||
TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
|
||||
return lastStoredId;
|
||||
return pinnedBlockIds;
|
||||
}
|
||||
|
||||
void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
|
||||
@ -2508,9 +2529,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
|
||||
mBlockManager.pinBlocks(sequence);
|
||||
}
|
||||
|
||||
void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
|
||||
void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
|
||||
{
|
||||
mBlockManager.unpinBlocksById(blockId);
|
||||
mBlockManager.unpinBlocksById(blockIds);
|
||||
}
|
||||
|
||||
SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
|
||||
|
||||
@@ -60,7 +60,8 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
bool MLACacheFormatter::needSendCache(
    CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
{
    int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
    int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
    int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;

    int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
        ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize

@ -296,7 +296,13 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
|
||||
// Parameters for sparse attention
|
||||
xqaParams.sparse_params = mRuntimeSparseAttentionParams;
|
||||
xqaParams.use_sparse_attention = useTllmGenSparseAttention();
|
||||
|
||||
// Skip softmax threshold.
|
||||
xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode;
|
||||
#ifdef SKIP_SOFTMAX_STAT
|
||||
// Statistics of skip-softmax, pointers of device memory for output
|
||||
xqaParams.skip_softmax_total_blocks = mSkipSoftmaxTotalBlocks;
|
||||
xqaParams.skip_softmax_skipped_blocks = mSkipSoftmaxSkippedBlocks;
|
||||
#endif
|
||||
// Cross attention parameters.
|
||||
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
|
||||
|
||||
@ -1313,6 +1319,8 @@ int AttentionOp::mlaGeneration(
|
||||
fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
|
||||
}
|
||||
|
||||
// MLA does not support skip-softmax attention right now
|
||||
|
||||
// Run the fmha kernel
|
||||
mDecoderFMHARunner->run(fmhaParams);
|
||||
}
|
||||
@ -1885,6 +1893,18 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
|
||||
}
|
||||
|
||||
// Skip-softmax attention parameters
|
||||
fmhaParams.skipSoftmaxThresholdScaleFactor = mSkipSoftmaxThresholdScaleFactorPrefill;
|
||||
#ifdef SKIP_SOFTMAX_STAT
|
||||
fmhaParams.skipSoftmaxTotalBlocks = mSkipSoftmaxTotalBlocks;
|
||||
fmhaParams.skipSoftmaxSkippedBlocks = mSkipSoftmaxSkippedBlocks;
|
||||
#else
|
||||
if (tensorrt_llm::common::getEnvPrintSkipSoftmaxStat())
|
||||
{
|
||||
TLLM_THROW("To print skip softmax stat, please run build_wheel.py with -DSKIP_SOFTMAX_STAT");
|
||||
}
|
||||
#endif
|
||||
|
||||
if (mAttentionChunkSize)
|
||||
{
|
||||
fmhaParams.chunkedAttentionSize = *mAttentionChunkSize;
|
||||
|
||||
@ -494,6 +494,14 @@ public:
|
||||
// See [Chunked Attention] in _torch/modules/attention.py
|
||||
std::optional<int64_t> mAttentionChunkSize = std::nullopt;
|
||||
|
||||
// Skip softmax threshold scale factor.
|
||||
float mSkipSoftmaxThresholdScaleFactorPrefill = 0;
|
||||
float mSkipSoftmaxThresholdScaleFactorDecode = 0;
|
||||
#ifdef SKIP_SOFTMAX_STAT
|
||||
uint32_t* mSkipSoftmaxTotalBlocks;
|
||||
uint32_t* mSkipSoftmaxSkippedBlocks;
|
||||
#endif
|
||||
|
||||
[[nodiscard]] auto data() const
|
||||
{
|
||||
return std::make_tuple(mLayerIdx, mNumHeads, mVisionStart, mVisionLength, mNumKVHeads, mHeadSize,
|
||||
@ -510,7 +518,8 @@ public:
|
||||
mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
|
||||
mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
|
||||
mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
|
||||
mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
|
||||
mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1), mSkipSoftmaxThresholdScaleFactorPrefill,
|
||||
mSkipSoftmaxThresholdScaleFactorDecode);
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
@@ -43,7 +43,7 @@ template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
    cudaGridDependencySynchronize();
#endif

    for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
@@ -63,7 +63,7 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
        }
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

@ -40,50 +40,12 @@ inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
|
||||
{
|
||||
return common::getEnvAllReduceWorkspaceSize();
|
||||
}
|
||||
if (worldSize <= 2)
|
||||
char const* envWorkspaceSize = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE");
|
||||
if (envWorkspaceSize != nullptr)
|
||||
{
|
||||
return 16 * 1000 * 1000;
|
||||
}
|
||||
return 8 * 1000 * 1000;
|
||||
}
|
||||
|
||||
// (SM major_version, TP_size) -> (NCCL_num_token_threshold, TWO_SHOT_numel_threshold)
|
||||
inline std::unordered_map<int, std::unordered_map<int, std::pair<size_t, size_t>>> HeuristicThresholdLP{
|
||||
{90,
|
||||
{
|
||||
{2, {4096, 4096 * 4096}},
|
||||
{4, {4096, 1024 * 1024}},
|
||||
{8, {2048, 512 * 512}},
|
||||
}},
|
||||
{100,
|
||||
{
|
||||
{2, {4096, 4096 * 4096}},
|
||||
{4, {4096, 1024 * 2048}},
|
||||
{8, {4096, 1024 * 1024}},
|
||||
}},
|
||||
};
|
||||
|
||||
inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size, int world_size, AllReduceFusionOp op)
|
||||
{
|
||||
// The heuristic is based on the following assumptions:
|
||||
// __________________________________
|
||||
// | \ TWO-SHOT zone |
|
||||
// | ONE-SHOT zone \ | NCCL zone
|
||||
// |_______________________\______|___
|
||||
// sm_major is 90 or 100
|
||||
|
||||
auto const sm_major = std::min(100, std::max(90, tensorrt_llm::common::getSMVersion()));
|
||||
|
||||
auto const [nccl_num_token_threshold, two_shot_numel_threshold] = HeuristicThresholdLP[sm_major][world_size];
|
||||
auto const message_size = seq_len * hidden_size;
|
||||
if (message_size >= two_shot_numel_threshold)
|
||||
{
|
||||
return AllReduceStrategyType::TWOSHOT;
|
||||
}
|
||||
else
|
||||
{
|
||||
return AllReduceStrategyType::ONESHOT;
|
||||
return static_cast<size_t>(std::atoi(envWorkspaceSize));
|
||||
}
|
||||
return 67108864; // 64 MiB
|
||||
}
|
||||
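With the heuristic table removed, the workspace size above is either taken from TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE (interpreted as bytes) or falls back to a fixed 64 MiB. A host-side sketch of that lookup:

    #include <cstddef>
    #include <cstdio>
    #include <cstdlib>

    static std::size_t maxWorkspaceBytes() {
        // Environment override takes precedence over the fixed default.
        if (char const* env = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE")) {
            return static_cast<std::size_t>(std::atoi(env));
        }
        return 67108864;  // 64 MiB default
    }

    int main() { std::printf("workspace = %zu bytes\n", maxWorkspaceBytes()); }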
|
||||
// use 1D vector to store the best strategy instead of a map for each sm version
|
||||
|
||||
@@ -249,7 +249,7 @@ bool getEnvUseTileSizeKv64ForTrtllmGen()
bool getEnvEnablePDL()
{
    static std::once_flag flag;
    static bool enablePDL = false;
    static bool enablePDL = true;

    std::call_once(flag,
        [&]()
@@ -257,7 +257,18 @@ bool getEnvEnablePDL()
            if (getSMVersion() >= 90)
            {
                // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
                enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
                char const* env = std::getenv("TRTLLM_ENABLE_PDL");
                if (env)
                {
                    if (env[0] == '1' && env[1] == '\0')
                    {
                        enablePDL = true;
                    }
                    else if (env[0] == '0' && env[1] == '\0')
                    {
                        enablePDL = false;
                    }
                };
            }
        });
    return enablePDL;
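The change above flips the PDL default to enabled and, on SM90+, only honours an explicit "0" or "1" in TRTLLM_ENABLE_PDL. A standalone sketch of that parsing, with getSMVersion replaced by a plain parameter:

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    static bool enablePdl(int smVersion) {
        bool enable = true;  // new default in the change above
        if (smVersion >= 90) {
            if (char const* env = std::getenv("TRTLLM_ENABLE_PDL")) {
                if (std::strcmp(env, "1") == 0) { enable = true; }
                else if (std::strcmp(env, "0") == 0) { enable = false; }
                // any other value keeps the default
            }
        }
        return enable;
    }

    int main() { std::printf("PDL enabled: %d\n", enablePdl(90)); }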
@ -281,6 +292,12 @@ bool getEnvUseNixlKvCache()
|
||||
return useNixlKvCache;
|
||||
}
|
||||
|
||||
bool getEnvUseMooncakeKvCache()
|
||||
{
|
||||
static bool const useMooncakeKvCache = getBoolEnv("TRTLLM_USE_MOONCAKE_KVCACHE");
|
||||
return useMooncakeKvCache;
|
||||
}
|
||||
|
||||
bool getEnvUseRoundRobinBlockDistForCP()
|
||||
{
|
||||
static bool const useRoundRobinBlockDistForCP = getBoolEnv("TRTLLM_USE_ROUND_ROBIN_BLOCK_DIST_FOR_CP");
|
||||
@ -343,6 +360,23 @@ std::string getEnvNixlBackend()
|
||||
return nixlBackend;
|
||||
}
|
||||
|
||||
std::string getEnvMooncakeInterface()
|
||||
{
|
||||
static std::once_flag flag;
|
||||
static std::string mooncakeInterface;
|
||||
|
||||
std::call_once(flag,
|
||||
[&]()
|
||||
{
|
||||
char const* mooncake_interface = std::getenv("TRTLLM_MOONCAKE_INTERFACE");
|
||||
if (mooncake_interface)
|
||||
{
|
||||
mooncakeInterface = mooncake_interface;
|
||||
}
|
||||
});
|
||||
return mooncakeInterface;
|
||||
}
|
||||
|
||||
bool getEnvDisaggLayerwise()
|
||||
{
|
||||
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
|
||||
@ -531,6 +565,11 @@ bool getEnvEplbForceGdrcopy()
|
||||
return getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY");
|
||||
}
|
||||
|
||||
bool getEnvPrintSkipSoftmaxStat()
|
||||
{
|
||||
return getBoolEnv("TRTLLM_PRINT_SKIP_SOFTMAX_STAT");
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
|
||||
@@ -83,8 +83,11 @@ inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 g
bool getEnvUseUCXKvCache();

bool getEnvUseMPIKvCache();

bool getEnvUseNixlKvCache();

bool getEnvUseMooncakeKvCache();

bool getEnvUseRoundRobinBlockDistForCP();

std::string getEnvUCXInterface();
@@ -93,6 +96,8 @@ std::string getEnvNixlInterface();

std::string getEnvNixlBackend();

std::string getEnvMooncakeInterface();

bool getEnvDisaggLayerwise();

bool getEnvParallelCacheSend();
@@ -156,6 +161,8 @@ bool getEnvKVCacheTransferAllBlocksForWindow();

bool getEnvEplbForceGdrcopy();

bool getEnvPrintSkipSoftmaxStat();

} // namespace common

TRTLLM_NAMESPACE_END

cpp/tensorrt_llm/common/ipUtils.cpp (new file, 226 lines)
@@ -0,0 +1,226 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ipUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <dirent.h>
|
||||
#include <fcntl.h>
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <string>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace common
|
||||
{
|
||||
|
||||
std::string getLocalIpByNic(std::string const& interface, int rank)
|
||||
{
|
||||
struct ifaddrs* ifaddr = nullptr;
|
||||
if (getifaddrs(&ifaddr) == -1)
|
||||
{
|
||||
TLLM_LOG_ERROR(rank,
|
||||
"getLocalIpByNic: Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is "
|
||||
"set "
|
||||
"correctly.");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
for (struct ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next)
|
||||
{
|
||||
if (ifa->ifa_addr == nullptr)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ifa->ifa_name == interface)
|
||||
{
|
||||
if (ifa->ifa_addr->sa_family == AF_INET)
|
||||
{
|
||||
char ip[INET_ADDRSTRLEN]{};
|
||||
void* addr = &((reinterpret_cast<struct sockaddr_in*>(ifa->ifa_addr))->sin_addr);
|
||||
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "0.0.0.0") != 0)
|
||||
{
|
||||
freeifaddrs(ifaddr);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
else if (ifa->ifa_addr->sa_family == AF_INET6)
|
||||
{
|
||||
char ip[INET6_ADDRSTRLEN]{};
|
||||
void* addr = &((reinterpret_cast<struct sockaddr_in6*>(ifa->ifa_addr))->sin6_addr);
|
||||
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
|
||||
&& std::strcmp(ip, "::1") != 0)
|
||||
{
|
||||
freeifaddrs(ifaddr);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
freeifaddrs(ifaddr);
|
||||
TLLM_LOG_ERROR(
|
||||
rank, "Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is set correctly.");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
std::string getLocalIpByHostname(int rank)
|
||||
{
|
||||
char hostname[256]{};
|
||||
if (gethostname(hostname, sizeof(hostname)) == -1)
|
||||
{
|
||||
TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
struct addrinfo hints = {};
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
hints.ai_flags = AI_CANONNAME;
|
||||
|
||||
struct addrinfo* res = nullptr;
|
||||
if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
|
||||
{
|
||||
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
for (struct addrinfo* p = res; p != nullptr; p = p->ai_next)
|
||||
{
|
||||
|
||||
if (p->ai_family == AF_INET)
|
||||
{ // IPv4
|
||||
char ip[INET_ADDRSTRLEN]{};
|
||||
struct sockaddr_in* ipv4 = reinterpret_cast<struct sockaddr_in*>(p->ai_addr);
|
||||
void* addr = &(ipv4->sin_addr);
|
||||
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "127.0.0.1") != 0
|
||||
&& std::strcmp(ip, "0.0.0.0") != 0)
|
||||
{
|
||||
freeaddrinfo(res);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
else if (p->ai_family == AF_INET6)
|
||||
{ // IPv6
|
||||
char ip[INET6_ADDRSTRLEN]{};
|
||||
struct sockaddr_in6* ipv6 = reinterpret_cast<struct sockaddr_in6*>(p->ai_addr);
|
||||
void* addr = &(ipv6->sin6_addr);
|
||||
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
|
||||
&& std::strcmp(ip, "::1") != 0)
|
||||
{
|
||||
freeaddrinfo(res);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
freeaddrinfo(res);
|
||||
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
|
||||
return std::string{};
|
||||
}
|
||||
|
||||
std::string getLocalIpByRemoteOrHostName(int rank)
|
||||
{
|
||||
|
||||
// Try IPv4
|
||||
struct sockaddr_in addr
|
||||
{
|
||||
};
|
||||
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_port = htons(80);
|
||||
// using google's public dns server to get the local ip which can be accessed from remote
|
||||
char const* dns_ip_v4 = "8.8.8.8";
|
||||
inet_pton(AF_INET, dns_ip_v4, &addr.sin_addr);
|
||||
|
||||
int sock = socket(AF_INET, SOCK_DGRAM, 0);
|
||||
if (sock != -1)
|
||||
{
|
||||
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != -1)
|
||||
{
|
||||
socklen_t addr_len = sizeof(addr);
|
||||
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) != -1)
|
||||
{
|
||||
char ip[INET_ADDRSTRLEN]{};
|
||||
inet_ntop(AF_INET, &addr.sin_addr, ip, sizeof(ip));
|
||||
close(sock);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
close(sock);
|
||||
}
|
||||
|
||||
// Try IPv6
|
||||
struct sockaddr_in6 addr6
|
||||
{
|
||||
};
|
||||
|
||||
addr6.sin6_family = AF_INET6;
|
||||
addr6.sin6_port = htons(80);
|
||||
// using google's public dns server
|
||||
char const* dns_ipv6 = "2001:4860:4860::8888";
|
||||
inet_pton(AF_INET6, dns_ipv6, &addr6.sin6_addr);
|
||||
|
||||
sock = socket(AF_INET6, SOCK_DGRAM, 0);
|
||||
if (sock != -1)
|
||||
{
|
||||
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr6), sizeof(addr6)) != -1)
|
||||
{
|
||||
socklen_t addr_len = sizeof(addr6);
|
||||
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr6), &addr_len) != -1)
|
||||
{
|
||||
char ip[INET6_ADDRSTRLEN]{};
|
||||
inet_ntop(AF_INET6, &addr6.sin6_addr, ip, sizeof(ip));
|
||||
close(sock);
|
||||
return std::string(ip);
|
||||
}
|
||||
}
|
||||
close(sock);
|
||||
}
|
||||
|
||||
// Try hostname
|
||||
return getLocalIpByHostname(rank);
|
||||
}
|
||||
|
||||
std::string getLocalIp(std::string interface, int rank)
|
||||
{
|
||||
std::string localIP = {};
|
||||
if (!interface.empty())
|
||||
{
|
||||
localIP = getLocalIpByNic(interface, rank);
|
||||
}
|
||||
if (localIP.empty())
|
||||
{
|
||||
localIP = getLocalIpByRemoteOrHostName(rank);
|
||||
}
|
||||
// check whether the localIP is valid
|
||||
if (localIP.empty())
|
||||
{
|
||||
TLLM_THROW("getLocalIp: Can't get local ip");
|
||||
}
|
||||
return localIP;
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
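getLocalIpByRemoteOrHostName above relies on the well-known trick of connect()-ing a UDP socket to a public address and reading the chosen local endpoint back with getsockname(); no packet is actually sent because UDP connect only fixes the route. An IPv4-only sketch of that step:

    #include <arpa/inet.h>
    #include <cstdio>
    #include <netinet/in.h>
    #include <sys/socket.h>
    #include <unistd.h>

    int main() {
        sockaddr_in addr{};
        addr.sin_family = AF_INET;
        addr.sin_port = htons(80);
        inet_pton(AF_INET, "8.8.8.8", &addr.sin_addr);  // public probe address, as above
        int const sock = socket(AF_INET, SOCK_DGRAM, 0);
        if (sock == -1) { return 1; }
        char ip[INET_ADDRSTRLEN]{};
        socklen_t len = sizeof(addr);
        // The kernel picks the outbound interface; getsockname reveals its address.
        if (connect(sock, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) == 0
            && getsockname(sock, reinterpret_cast<sockaddr*>(&addr), &len) == 0) {
            inet_ntop(AF_INET, &addr.sin_addr, ip, sizeof(ip));
            std::printf("local ip: %s\n", ip);
        }
        close(sock);
    }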
cpp/tensorrt_llm/common/ipUtils.h (new file, 28 lines)
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/config.h"
|
||||
#include <string>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace common
|
||||
{
|
||||
std::string getLocalIp(std::string interface, int rank);
|
||||
} // namespace common
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
@ -37,6 +37,46 @@ NcclCommResourceManager& NcclCommResourceManager::getInstance() noexcept
|
||||
return instance;
|
||||
}
|
||||
|
||||
NcclCommResourceManager::~NcclCommResourceManager()
|
||||
{
|
||||
// Mark that we're in destruction to prevent cleanup attempts from deleters
|
||||
// that may run during static destruction
|
||||
mIsDestroying.store(true, std::memory_order_release);
|
||||
|
||||
// Proactively clean up all resources before destruction
|
||||
// This ensures cleanup happens in a controlled manner before static destruction
|
||||
std::vector<std::pair<ncclComm_t, std::vector<ResourceEntry>>> allResources;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
// Move all resources out of the map
|
||||
allResources.reserve(mCommResources.size());
|
||||
for (auto& [comm, resources] : mCommResources)
|
||||
{
|
||||
allResources.emplace_back(comm, std::move(resources));
|
||||
}
|
||||
mCommResources.clear();
|
||||
}
|
||||
|
||||
// Clean up all resources outside the lock
|
||||
// Note: We don't call ncclCommDestroy here - that's the responsibility
|
||||
// of the shared_ptr deleter. We just clean up registered resources.
|
||||
for (auto& [comm, resources] : allResources)
|
||||
{
|
||||
for (auto& [cleanup, name] : resources)
|
||||
{
|
||||
try
|
||||
{
|
||||
cleanup();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore exceptions during destruction
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
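The new destructor and destruction guard above follow a register-cleanup-callbacks-per-communicator pattern: callbacks run before the communicator is destroyed, outside the lock, and are skipped (or taken over by the destructor) once static destruction has begun. A minimal sketch with a plain int standing in for ncclComm_t and no logging:

    #include <atomic>
    #include <cstdio>
    #include <functional>
    #include <mutex>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    class ResourceManager {
    public:
        void registerResource(int comm, std::function<void()> cleanup) {
            std::lock_guard<std::mutex> lock(mMutex);
            mResources[comm].push_back(std::move(cleanup));
        }
        void cleanupResources(int comm) noexcept {
            if (mIsDestroying.load(std::memory_order_acquire)) { return; }  // destructor owns cleanup now
            std::vector<std::function<void()>> toRun;
            {
                std::lock_guard<std::mutex> lock(mMutex);
                auto it = mResources.find(comm);
                if (it == mResources.end()) { return; }
                toRun = std::move(it->second);
                mResources.erase(it);
            }
            for (auto& fn : toRun) { try { fn(); } catch (...) {} }  // run outside the lock
        }
        ~ResourceManager() {
            mIsDestroying.store(true, std::memory_order_release);
            for (auto& [comm, fns] : mResources) {
                for (auto& fn : fns) { try { fn(); } catch (...) {} }
            }
        }
    private:
        std::mutex mMutex;
        std::unordered_map<int, std::vector<std::function<void()>>> mResources;
        std::atomic<bool> mIsDestroying{false};
    };

    int main() {
        ResourceManager mgr;
        mgr.registerResource(7, [] { std::printf("freeing buffer for comm 7\n"); });
        mgr.cleanupResources(7);
    }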
||||
void NcclCommResourceManager::registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName)
|
||||
{
|
||||
if (!comm)
|
||||
@ -60,23 +100,56 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if we're in the process of being destroyed
|
||||
// If so, skip cleanup - the destructor will handle it proactively
|
||||
if (mIsDestroying.load(std::memory_order_acquire))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<ResourceEntry> resourcesToClean;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
auto it = mCommResources.find(comm);
|
||||
if (it == mCommResources.end())
|
||||
// During static destruction, mutex and logging may not be safe.
|
||||
// Use try-catch to handle any issues gracefully.
|
||||
try
|
||||
{
|
||||
// Nothing registered for this comm, nothing to clean up
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
|
||||
// Double-check after acquiring lock (destruction may have started)
|
||||
if (mIsDestroying.load(std::memory_order_acquire))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto it = mCommResources.find(comm);
|
||||
if (it == mCommResources.end())
|
||||
{
|
||||
// Nothing registered for this comm, nothing to clean up
|
||||
return;
|
||||
}
|
||||
|
||||
// Move resources out (preserves order) and remove from map
|
||||
resourcesToClean = std::move(it->second);
|
||||
mCommResources.erase(it);
|
||||
|
||||
// Logging may fail during static destruction, so wrap in try-catch
|
||||
try
|
||||
{
|
||||
TLLM_LOG_TRACE("[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(),
|
||||
static_cast<void*>(comm));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// If mutex access fails during static destruction, just return.
|
||||
// This prevents segfaults when the singleton is being destroyed.
|
||||
return;
|
||||
}
|
||||
|
||||
// Move resources out (preserves order) and remove from map
|
||||
resourcesToClean = std::move(it->second);
|
||||
mCommResources.erase(it);
|
||||
|
||||
TLLM_LOG_TRACE(
|
||||
"[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(), static_cast<void*>(comm));
|
||||
}
|
||||
|
||||
// Clean up outside the lock to avoid deadlocks if cleanup functions try to access the manager
|
||||
@ -85,19 +158,41 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
|
||||
{
|
||||
try
|
||||
{
|
||||
TLLM_LOG_TRACE(
|
||||
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
|
||||
// Logging may fail during static destruction, so wrap in try-catch
|
||||
try
|
||||
{
|
||||
TLLM_LOG_TRACE(
|
||||
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
cleanup();
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s", name.c_str(),
|
||||
static_cast<void*>(comm), e.what());
|
||||
try
|
||||
{
|
||||
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s",
|
||||
name.c_str(), static_cast<void*>(comm), e.what());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
|
||||
name.c_str(), static_cast<void*>(comm));
|
||||
try
|
||||
{
|
||||
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
|
||||
name.c_str(), static_cast<void*>(comm));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
@ -139,12 +140,13 @@ public:
|
||||
|
||||
private:
|
||||
NcclCommResourceManager() = default;
|
||||
~NcclCommResourceManager() = default;
|
||||
~NcclCommResourceManager();
|
||||
|
||||
using ResourceEntry = std::pair<ResourceCleanupFunc, std::string>;
|
||||
|
||||
mutable std::mutex mMutex;
|
||||
std::unordered_map<ncclComm_t, std::vector<ResourceEntry>> mCommResources;
|
||||
std::atomic<bool> mIsDestroying{false};
|
||||
};
|
||||
|
||||
// RAII helper to register a resource with a NCCL communicator.
|
||||
|
||||
@ -123,13 +123,24 @@ std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
|
||||
if (*comm)
|
||||
{
|
||||
// Clean up all registered resources FIRST
|
||||
// The cleanupResources function uses a destruction guard to safely handle
|
||||
// static destruction order issues - it will return early if the singleton
|
||||
// is being destroyed (in which case the destructor handles cleanup proactively)
|
||||
tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*comm);
|
||||
|
||||
// Now destroy the NCCL communicator
|
||||
ncclResult_t result = ncclCommDestroy(*comm);
|
||||
if (result != ncclSuccess)
|
||||
{
|
||||
TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
|
||||
// Logging may fail during static destruction, so wrap in try-catch
|
||||
try
|
||||
{
|
||||
TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// Ignore logging failures during static destruction
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the communicator value before freeing the pointer
|
||||
|
||||
@ -46,7 +46,7 @@ CUTLASS_DEVICE
|
||||
void launch_dependent_grids()
|
||||
{
|
||||
#if (defined(CUTLASS_GDC_ENABLED))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ CUTLASS_DEVICE
|
||||
void wait_on_dependent_grids()
|
||||
{
|
||||
#if (defined(CUTLASS_GDC_ENABLED))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -686,4 +686,212 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <class Collective>
|
||||
struct MixedInputUtilsSM100
|
||||
{
|
||||
private:
|
||||
using KernelSchedule = typename Collective::KernelSchedule;
|
||||
using ConversionMode = typename Collective::ConversionMode;
|
||||
using SmemLayoutA = typename Collective::SmemLayoutA;
|
||||
using SmemLayoutB = typename Collective::SmemLayoutB;
|
||||
using ElementScale = typename Collective::ElementScale;
|
||||
using ElementZero = typename Collective::ElementZero;
|
||||
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
|
||||
|
||||
public:
|
||||
// Helper functions to select packing for conversion
|
||||
template <class SrcType, class DstType, int Cosize>
|
||||
struct select_packing
|
||||
{ // Naive packing policy
|
||||
|
||||
static constexpr auto value()
|
||||
{
|
||||
return Int<cute::gcd(Cosize, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
|
||||
}
|
||||
};
|
||||
|
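select_packing above picks how many elements each NumericArrayConverter call handles: the gcd of the fragment size and the number of the narrower elements that fit in a 32-bit register. A tiny sketch of the same rule with concrete, illustrative bit widths:

    #include <cstdio>
    #include <numeric>

    constexpr int selectPack(int cosize, int srcBits, int dstBits) {
        int const minBits = srcBits < dstBits ? srcBits : dstBits;
        return std::gcd(cosize, 32 / minBits);  // whole 32-bit registers per conversion
    }

    int main() {
        // e.g. 4-bit weights converted to 16-bit bf16, 64 elements per fragment:
        std::printf("pack = %d\n", selectPack(64, 4, 16));  // gcd(64, 32/4) = 8
    }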
||||
    /// (Designed for separate transform pipeline in Blackwell)
    /// Utilities to dequantize A.
    template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
    CUTLASS_DEVICE static void dequantize_A_kblock_for_transform(Tensor<EngineIn, LayoutIn> const& tArA,
        Tensor<EngineOut, LayoutOut>& tArACompute, cute::tuple<Ts...> const& partitioned_extra_info, int const k_block)
    {
        static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
        static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
        static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
        static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
        static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
        using SrcType = typename EngineIn::value_type;
        using DstType = typename EngineOut::value_type;

        auto src = tArA(_, _, _, k_block);
        auto dst = tArACompute(_, _, _, k_block);
        auto pSrc = raw_pointer_cast(src.data());
        auto pDst = const_cast<DstType*>(raw_pointer_cast(dst.data()));
        constexpr int num_elements = decltype(size(src))::value;

        constexpr int pack = decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
        using Converter
            = cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
        using SrcArray = cutlass::Array<SrcType, pack>;
        using DstArray = cutlass::Array<DstType, pack>;
        constexpr int DstElementsPerReg = 32 / sizeof_bits_v<DstType>;
        using RegArray = cutlass::AlignedArray<uint32_t, pack / DstElementsPerReg, sizeof(DstArray)>;

        auto src_arr = recast<SrcArray>(src);
        auto dst_arr = recast<DstArray>(dst);

        Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, pack));

        if constexpr (KernelConversionMode == ConversionMode::DirectConvert)
        {
            cute::transform(src_arr, dst_arr, Converter::convert);
        }
        else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
        {
            auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);

            CUTE_STATIC_ASSERT_V(size(src) == size(scales));

            if constexpr (is_same_v<DstType, ElementScale>)
            {
                cute::transform(src_arr, dst_arr, Converter::convert);

                using ScaleArray = cutlass::Array<ElementScale, pack>;
                auto scale_arr = recast<ScaleArray>(filter_zeros(scales));

                if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
                {
                    Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));

                    for (int i = 0; i < size<1>(dst_vm); ++i)
                    {
                        auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
                        auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
                        CUTLASS_PRAGMA_UNROLL
                        for (size_t ii = 0; ii < RegArray::kElements; ++ii)
                        {
                            __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
                            bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
                        }
                    }
                }
                else
                {
                    cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
                }
            }
            else
            {
                constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
                constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
                constexpr int pack = cute::gcd(pack1, pack2);
                using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
                    cutlass::FloatRoundStyle::round_to_nearest>;
                using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
                    cutlass::FloatRoundStyle::round_to_nearest>;
                using SrcArray = cutlass::Array<SrcType, pack>;
                using DstArray = cutlass::Array<DstType, pack>;
                using StageArray = cutlass::Array<ElementScale, pack>;
                constexpr int iters = num_elements / pack;

                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < iters; ++i)
                {
                    SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
                    DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
                    StageArray stageArr;
                    stageArr = Converter1::convert(*pSrcArr);
                    CUTLASS_PRAGMA_UNROLL
                    for (int j = 0; j < pack; ++j)
                    {
                        stageArr[j] = stageArr[j] * scales[i * pack + j];
                    }
                    *pDstArr = Converter2::convert(stageArr);
                }
            }
        }
        else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
        {
            static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");

            auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
            auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, _, k_block);
            CUTE_STATIC_ASSERT_V(size(src) == size(scales));
            CUTE_STATIC_ASSERT_V(size(src) == size(zeros));

            if constexpr (is_same_v<DstType, ElementZero>)
            {
                cute::transform(src_arr, dst_arr, Converter::convert);

                using ScaleArray = cutlass::Array<ElementScale, pack>;
                auto scale_arr = recast<ScaleArray>(filter_zeros(scales));

                using ZeroArray = cutlass::Array<ElementZero, pack>;
                auto zero_arr = recast<ZeroArray>(filter_zeros(zeros));

                if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
                {
                    Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
                    Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, pack));

                    for (int i = 0; i < size<1>(dst_vm); ++i)
                    {
                        auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
                        auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
                        auto&& zero_reg = cute::recast<RegArray>(zeros_vm(_, i))(0);
                        CUTLASS_PRAGMA_UNROLL
                        for (size_t ii = 0; ii < RegArray::kElements; ++ii)
                        {
                            __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
                            bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
                            bf16x2_val = __hadd2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(zero_reg[ii]));
                        }
                    }
                }
                else
                {
                    cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
                    cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{});
                }
            }
            else
            {
                constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
                constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
                constexpr int pack = cute::gcd(pack1, pack2);
                using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
                    cutlass::FloatRoundStyle::round_to_nearest>;
                using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
                    cutlass::FloatRoundStyle::round_to_nearest>;
                using SrcArray = cutlass::Array<SrcType, pack>;
                using DstArray = cutlass::Array<DstType, pack>;
                using StageArray = cutlass::Array<ElementScale, pack>;
                constexpr int iters = num_elements / pack;

                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < iters; ++i)
                {
                    SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
                    DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
                    StageArray stageArr;
                    stageArr = Converter1::convert(*pSrcArr);
                    CUTLASS_PRAGMA_UNROLL
                    for (int j = 0; j < pack; ++j)
                    {
                        stageArr[j] = stageArr[j] * scales[i * pack + j] + zeros[i * pack + j];
                    }
                    *pDstArr = Converter2::convert(stageArr);
                }
            }
        }
        else
        {
            static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                "Conversion mode not handled for input partitioning.");
        }
    }
};
} // namespace cutlass::gemm::collective::detail
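A hedged scalar sketch of the three conversion modes handled above (DirectConvert, ConvertAndScale, ConvertAndScaleWithZero). The real kernel works on register fragments with vectorized NumericArrayConverter and packed bf16x2 math; per element the arithmetic reduces to the following.

// Illustrative only: scalar dequantization with int8 inputs and float outputs.
#include <cstdint>
#include <vector>

enum class Mode { DirectConvert, ConvertAndScale, ConvertAndScaleWithZero };

std::vector<float> dequantize(std::vector<int8_t> const& src, std::vector<float> const& scales,
    std::vector<float> const& zeros, Mode mode)
{
    std::vector<float> dst(src.size());
    for (size_t i = 0; i < src.size(); ++i)
    {
        float v = static_cast<float>(src[i]); // numeric conversion
        if (mode != Mode::DirectConvert)
        {
            v *= scales[i]; // per-element (really per-group) scale
        }
        if (mode == Mode::ConvertAndScaleWithZero)
        {
            v += zeros[i]; // zero-point applied after scaling, as in the kernel
        }
        dst[i] = v;
    }
    return dst;
}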
@@ -0,0 +1,294 @@
/*
 * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cutlass/gemm/collective/builders/sm100_common.inl"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail
{

// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
    class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
    int stages>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(StageCount<stages> stage_count)
{
    constexpr int Load2TransformStageCount = stages;
    constexpr int Transform2MmaStageCount = stages;
    constexpr int AccumulatorStageCount = stages;
    return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
    class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
    int carveout_bytes>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(
    StageCountAutoCarveout<carveout_bytes> stage_count)
{
    constexpr int CtaM = get<0>(CtaTileShape_MNK{});
    constexpr int CtaN = get<1>(CtaTileShape_MNK{});
    static_assert(CtaN <= 128, "Can't support CtaN>128 tiles");
    constexpr int CtaK = get<2>(CtaTileShape_MNK{});
    using AtomThrID = typename TiledMma::AtomThrID;

    constexpr int TmemColumns = 512;

    constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K
        && !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>;
    constexpr bool IsAComputeinSmem = !IsAComputeinTmem;

    // Detect 2x2 TMEM layout
    constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN / 2 : CtaN;
    constexpr int TmemAWordsPerDP = CtaK / 2;

    constexpr int AccumulatorStageCount
        = (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP);

    constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32);

    constexpr int TmemInAStageCount_Potential
        = (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000;

    // Mainload2Transform Pipeline
    constexpr auto load2transform_pipeline_bytes
        = sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage);
    constexpr auto a_bits = cute::sizeof_bits_v<ElementA>; // ElementA introduce here
    constexpr auto s_bits = cute::is_void_v<ElementScale> ? 0 : cute::sizeof_bits_v<ElementScale>;
    constexpr auto z_bits = cute::is_void_v<ElementZero> ? 0 : cute::sizeof_bits_v<ElementZero>;

    constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage);
    constexpr auto b_bits = cute::sizeof_bits_v<ElementB>; // ElementB introduce here

    constexpr int ab_stage_bytes
        = cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
        + cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
        + cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
        + cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{}))
        + static_cast<int>(load2transform_pipeline_bytes) + static_cast<int>(load2mma_pipeline_bytes);

    // Transform2Mma Pipeline
    constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage);
    constexpr auto a_compute_bits = cute::sizeof_bits_v<ElementAMma>;
    constexpr int ab_compute_stage_bytes = cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem)
                                               * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
        + // If ACompute is in TMEM, Acompute buffer has 0 bytes.
        static_cast<int>(transform2mma_pipeline_bytes);

    constexpr int ABComputeStageCount_Potential
        = SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes);

    // The number of SMEM buffers for A, B. ACompute (if in SMEM), BCompute should be at least Transform2MmaStageCount
    constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential);

    constexpr int SmemCapacityAfterABComputeCarveout
        = SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes);

    // Can we boost the number of buffers for A and B?
    constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes;

    static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2,
        "Not enough SMEM or TMEM capacity for selected tile size");
    return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}
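A simplified, hedged sketch of the stage-count arithmetic above: reserve the carveout and the compute-side buffers first, then spend whatever SMEM remains on extra load stages. All byte counts below are made up; the real function also caps the result by TMEM capacity.

// Illustrative budgeting helper; not the CUTLASS builder code.
constexpr int load_stages_for(int capacity_bytes, int carveout_bytes, int ab_stage_bytes,
    int ab_compute_stage_bytes, int transform2mma_stages)
{
    int const after_carveout = capacity_bytes - carveout_bytes;
    int const after_compute = after_carveout - transform2mma_stages * ab_compute_stage_bytes;
    return after_compute / ab_stage_bytes; // Load2Transform stage count
}

static_assert(load_stages_for(/*capacity*/ 232448, /*carveout*/ 8192, /*ab*/ 40960,
                  /*ab_compute*/ 16384, /*transform2mma*/ 4) > 0);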
} // namespace detail

// Mixed Input MMA kernels builder
template <class ElementAOptionalTuple, class GmemLayoutATagTuple, int AlignmentA, class ElementBOptionalTuple,
    class GmemLayoutBTag, int AlignmentB, class ElementAccumulator,
    class TileShape_MNK, // The Cluster-level TileShape
    class ClusterShape_MNK, class StageCountType, class KernelScheduleType>
struct CollectiveBuilderSm100WeightOnly<arch::Sm100, arch::OpClassTensorOp,
    ElementAOptionalTuple, // ElementA
    GmemLayoutATagTuple,   // LayoutA
    AlignmentA,
    ElementBOptionalTuple, // ElementB
    GmemLayoutBTag,        // LayoutB
    AlignmentB, ElementAccumulator,
    TileShape_MNK,    // (MmaAtomShapeM, MmaAtomShapeN, TileK)
    ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int)
    StageCountType, KernelScheduleType,
    cute::enable_if_t<(cute::is_base_of_v<KernelScheduleSm100MixedInputGemm, KernelScheduleType>) &&(
        (sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0)
        && ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>>
{
    using GmemLayoutATag = detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>;
    using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>;

    static constexpr cute::UMMA::Major UmmaMajorA
        = cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
    static constexpr cute::UMMA::Major UmmaMajorB
        = cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();

    using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>;
    using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>;
    using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>;
    using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>;

    static constexpr bool NeitherIsTuple
        = !cute::is_tuple<ElementAOptionalTuple>::value && !cute::is_tuple<ElementBOptionalTuple>::value;
    static constexpr bool IsANarrow = cute::sizeof_bits_v<ElementA> < cute::sizeof_bits_v<ElementB>;
    static constexpr bool IsMixedInput = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
    static_assert(IsMixedInput, "Mixed Input GEMM Kernel doesn't support regular gemm.");

    static_assert(
        (cute::is_tuple<ElementAOptionalTuple>::value ^ cute::is_tuple<ElementBOptionalTuple>::value
            || (NeitherIsTuple && (cute::sizeof_bits<ElementA>::value != cute::sizeof_bits<ElementB>::value))),
        "Either A OR B must be a tuple or the widths of A and B must be different.");
    using ElementPairA = cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>,
        ElementAOptionalTuple>;
    using ElementPairB = cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>,
        ElementBOptionalTuple>;
    static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
    static_assert(IsATransformed, "A matrix should be transformed.");

    // For fp32 types, map to tf32 MMA value type.
    using ElementMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

    using ElementAMma = ElementMma;
    using ElementBMma = ElementMma;

    static constexpr int IsSubbyteA = cute::sizeof_bits_v<ElementA> < 8;
    using TmaElementA = cute::conditional_t<IsSubbyteA, uint8_t, ElementA>;

    static constexpr int ScalingFactor = 1;

    using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma<ElementAMma, ElementB,
        ElementAccumulator, TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB, KernelScheduleType>());
    using AtomThrID = typename TiledMma::AtomThrID;
    using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
    using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{}));

    // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
    using MmaShapeA_MK = decltype(partition_shape_A(
        TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
    // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
    using MmaShapeB_NK = decltype(partition_shape_B(
        TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));

    using BlockTileA_M = decltype(cute::size<0, 0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
    using BlockTileA_K = decltype(cute::size<0, 1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));

    using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{})));
    using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}));

    // Input transform kernel can not use TMA 2SM instructions.
    using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA, ElementA,
        BlockTileA_M, BlockTileA_K>());
    using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA,
        ElementAMma, BlockTileA_M, BlockTileA_K>());
    using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomA,
        SmemLayoutAtomACompute>;
    static constexpr int MMA_M = cute::size<0, 0>(MmaShapeA_MK{});
    using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>,
        cute::conditional_t<
            (UmmaMajorA == cute::UMMA::Major::K
                && !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>),
            cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x,
                SM100_TMEM_STORE_32dp32b8x>,                                   // TS Implementation
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>> // SS Implementation
        >;

    using BlockTileB_N = decltype(cute::size<0, 0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
    using BlockTileB_K = decltype(cute::size<0, 1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));

    // Input transform kernel can not use TMA 2SM instructions.
    using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB, ElementB,
        BlockTileB_N, BlockTileB_K>());
    using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB,
        ElementBMma, BlockTileB_N, BlockTileB_K>());
    using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomB,
        SmemLayoutAtomBCompute>;
    using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementB>,
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementMma>>;

    // Creating the stride of Transformed Input
    using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
    using LayoutScale = cutlass::gemm::TagToStrideA_t<GmemLayoutScaleTag>;

    using VoidShapeScale
        = Shape<Shape<Int<128>, _1>, Shape<Int<64>, _1>, _1>; // Dummy Value to create a dummy ScaleConfig
    using VoidStrideScale = Stride<Stride<_0, _1>, Stride<_0, _1>, _1>;
    using VoidLayoutScale = Layout<VoidShapeScale, VoidStrideScale>;

    using NonVoidLayoutScale = cute::conditional_t<cute::is_void_v<LayoutScale>, VoidLayoutScale, LayoutScale>;

    using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{}));
    // SmemCarveout
    static constexpr int SchedulerPipelineStageCount = 3;
    static constexpr bool IsArrayOfPointersGemm
        = (cute::is_base_of_v<KernelScheduleSm100PtrArrayFastFP32Gemm, KernelScheduleType>);

    // CLCPipeline = PipelineCLCFetchAsync
    static constexpr auto CLCPipelineStorage
        = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
    // CLC (scheduler) response
    static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
    // CLC Throttle pipeline storage
    static constexpr auto CLCThrottlePipelineStorage
        = sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
    // Tmem dealloc
    static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
    // Tmem ptr storage
    static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t);
    // Tensormap Storage
    static constexpr size_t TensorMapStorage
        = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;

    // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
    static constexpr auto KernelSmemCarveout = static_cast<int>(CLCPipelineStorage + CLCResponseStorage
        + CLCThrottlePipelineStorage + TmemDeallocStorage + TmemBasePtrsStorage + TensorMapStorage);

    // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations
    static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout;

    static constexpr int ScaleGranularityK = get_ScaleGranularityK<LayoutScale>();

    static constexpr auto stage_info
        = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_weightonly<
            Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB,
            CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{});

    static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info);
    static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info);
    static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info);

    static_assert(!IsArrayOfPointersGemm, "mixed input does not support grouped gemm on Blackwell");

    using DispatchPolicy
        = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput<Load2TransformPipelineStageCount,
            Transform2MmaPipelineStageCount, SchedulerPipelineStageCount, AccumulatorPipelineStageCount,
            ClusterShape_MNK>;
    using CollectiveOp = cutlass::gemm::collective::CollectiveMmaSm100WeightOnly<DispatchPolicy, TileShape_MNK,
        ElementPairA, StridePairA, ElementPairB, cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>, TiledMma,
        GmemTiledCopyA, SmemLayoutAtomPairA, CopyAtomPairA, cute::identity, GmemTiledCopyB, SmemLayoutAtomPairB,
        CopyAtomPairB, cute::identity>;
};

} // namespace cutlass::gemm::collective
@@ -0,0 +1,42 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp"

namespace cutlass::gemm::collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
    class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
    class StageCountType, class KernelScheduleType, class Enable = void>
struct CollectiveBuilderSm100WeightOnly
{
    static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,42 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass/detail/dependent_false.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
    class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
    class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB>
struct CollectiveMmaSm100WeightOnly
{
    static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large.
@@ -533,8 +533,8 @@ struct GemmFpAIntB
        run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
        run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 1000)
        // Use SM80 implementation for GB10x, GB20x.
#elif (__CUDA_ARCH__ >= 1200)
        // Use SM80 implementation for GB20x.
        run_kernel<arch::Sm80>(params, shared_storage);
#else
        CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
@@ -87,7 +87,9 @@ public:
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
        && Arch::kMinComputeCapability != 103>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

@@ -102,7 +104,9 @@ public:
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
        && Arch::kMinComputeCapability != 103>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

@@ -116,6 +120,26 @@ public:
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
@@ -38,7 +38,13 @@ foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES})
  if(FILE_EXT STREQUAL ".py")
    # Read file content and replace module imports for Python files
    file(READ ${SOURCE_FILE} _content)
    string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content
    string(REPLACE "from . import _C" "import tensorrt_llm.deep_gemm_cpp_tllm"
                   _content "${_content}")
    string(REPLACE ".._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
                   "${_content}")
    string(REPLACE "._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
                   "${_content}")
    string(REPLACE "_C." "tensorrt_llm.deep_gemm_cpp_tllm." _content
                   "${_content}")

    # Add adaptation header

@@ -90,4 +90,5 @@ target_compile_definitions(${EXECUTOR_STATIC_TARGET}
                           PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")

add_subdirectory(cache_transmission/ucx_utils)
add_subdirectory(cache_transmission/mooncake_utils)
add_subdirectory(cache_transmission/nixl_utils)

@@ -141,7 +141,8 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
    NotificationInfo notificationInfo{syncInfo};
    std::stringstream ss;
    NotificationInfo::serialize(notificationInfo, ss);
    status->wait();
    TransferState transferState = status->wait();
    TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "AgentConnection::send failed");
    // TODO: there is a bug in request_with_notify https://github.com/ai-dynamo/nixl/pull/252
    mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
}

@@ -150,7 +151,7 @@ void AgentConnection::recv(DataContext const& ctx, void* data, size_t size) cons
{

    NotificationSyncInfo syncInfo{mAgentName, ctx};
    mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo);
    mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo, ctx.getTransferTerminate());
}

void AgentConnection::sendRequestAndBufferInfo(batch_manager::RequestInfo& requestInfo,

@@ -230,13 +231,13 @@ void AgentConnection::sendReadySignal(DataContext const& ctx, bool isReady) cons
bool AgentConnection::recvReadySignal(DataContext const& ctx) const
{
    ReadySignalInfo readySignalInfo{mAgentName, ctx, false};
    mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo);
    return true;
    mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo, ctx.getTransferTerminate());
    return readySignalInfo.mIsReady;
}

AgentConnectionManager::AgentConnectionManager(
    std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
    CacheState cacheState)
    CacheState cacheState, std::string const& backendType)
    : mCacheState(std::move(cacheState))
    , mCacheTransBufferManagers(std::move(cacheTransBufferManagers))
    , mRegMemDescs(MemoryType::kVRAM, {})

@@ -246,8 +247,8 @@ AgentConnectionManager::AgentConnectionManager(

    mAgentName = genUniqueAgentName();
    // Create Agent
    BaseAgentConfig config{mAgentName, true};
    m_Agent = makeTransferAgent("nixl", &config);
    BaseAgentConfig config{mAgentName, true, false, true, 1};
    m_Agent = makeTransferAgent(backendType, &config);
    TLLM_CHECK(!mCacheTransBufferManagers.empty());
    std::vector<MemoryDesc> memDescs;
    for (auto* cacheTransBufferManager : mCacheTransBufferManagers)

@@ -315,9 +316,10 @@ AgentConnectionManager::AgentConnectionManager(
        " ***** AgentConnectionManager::AgentConnectionManager mCommState: %s", mCommState.toString().c_str());
}

AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batch_manager::RequestInfo& requestInfo)
AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(
    batch_manager::RequestInfo& requestInfo, std::atomic<bool> const& terminateFlag)
{
    while (true)
    while (!terminateFlag.load())
    {
        if (!mIsRunning)
        {

@@ -490,16 +492,16 @@ int AgentConnectionManager::getDeviceId() const
}

template <typename NotificationType>
void AgentConnectionManager::waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo)
void AgentConnectionManager::waitForNotification(
    std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag)
{
    while (true)
    while (!terminateFlag.load())
    {

        if (!mIsRunning)
        {
            return;
        }

        updateUnhandledNotifications();
        std::scoped_lock lock(mNotificationMutex);
        auto it = mUnhandledNotifications.begin();

@@ -575,18 +577,20 @@ void AgentConnectionManager::waitForNotification(std::string const& remoteAgentN

// Explicit template instantiations
template void AgentConnectionManager::waitForNotification<NotificationSyncInfo>(
    std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo);
    std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo, std::atomic<bool> const& terminateFlag);
template void AgentConnectionManager::waitForNotification<ReadySignalInfo>(
    std::string const& remoteAgentName, ReadySignalInfo& expectedInfo);
    std::string const& remoteAgentName, ReadySignalInfo& expectedInfo, std::atomic<bool> const& terminateFlag);

void AgentConnectionManager::waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo)
void AgentConnectionManager::waitForSyncInfo(
    std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag)
{
    waitForNotification(remoteAgentName, syncInfo);
    waitForNotification(remoteAgentName, syncInfo, terminateFlag);
}

void AgentConnectionManager::waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo)
void AgentConnectionManager::waitForReadySignal(
    std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag)
{
    waitForNotification(remoteAgentName, readySignalInfo);
    waitForNotification(remoteAgentName, readySignalInfo, terminateFlag);
}

std::string const& AgentConnectionManager::getAgentName() const

@@ -277,12 +277,13 @@ class AgentConnectionManager : public ConnectionManager
public:
    AgentConnectionManager(
        std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
        CacheState cacheState);
        CacheState cacheState, std::string const& backendType);
    ~AgentConnectionManager();
    AgentConnection* recvConnect(DataContext const& ctx, void* data, size_t size) override;
    [[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
    [[nodiscard]] CommState const& getCommState() const override;
    AgentConnection const* recvConnectionAndRequestInfo(batch_manager::RequestInfo& requestInfo);
    AgentConnection const* recvConnectionAndRequestInfo(
        batch_manager::RequestInfo& requestInfo, std::atomic<bool> const& terminateFlag);
    [[nodiscard]] std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> const&
    getCacheTransBufferManagers() const;
    void updateUnhandledNotifications();

@@ -293,9 +294,12 @@ public:
    [[nodiscard]] std::string const& getAgentName() const;

    template <typename NotificationType>
    void waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo);
    void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo);
    void waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo);
    void waitForNotification(
        std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag);
    void waitForSyncInfo(
        std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag);
    void waitForReadySignal(
        std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag);
    [[nodiscard]] bool isRunning() const override;

private:

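A hedged illustration of what the new terminateFlag parameter buys: the wait loops above can now be unblocked from another thread instead of spinning forever. Names here are illustrative, not the TensorRT-LLM API.

// Standalone sketch of a terminate-aware polling loop.
#include <atomic>
#include <chrono>
#include <thread>

bool wait_until_ready(std::atomic<bool> const& terminateFlag, std::atomic<bool> const& ready)
{
    while (!terminateFlag.load())
    {
        if (ready.load())
        {
            return true; // notification arrived
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    return false; // shutdown requested before the notification showed up
}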
@@ -107,9 +107,9 @@ TargetRanksInfo TargetRanksInfoForDP(
    auto const peerCPNum = peerParConfig.mContextParallelism;
    auto const selfCPNum = selfParConfig.mContextParallelism;

    auto const selfTPRank = selfRank % selfTPNum;
    auto const selfCPRank = selfRank % selfCPNum;
    auto const selfTPRank = (selfRank % (selfTPNum * selfCPNum)) / selfCPNum;
    auto const selfPPRank = selfRank / (selfTPNum * selfCPNum);
    auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) / selfTPNum;

    int peerPPRankStart = 0;
    int mDomainPPSize = 1;

@@ -205,13 +205,14 @@ TargetRanksInfo TargetRanksInfoForDP(
    }

    std::vector<int> retRanks;
    for (int i = peerTPRankStart; i < peerTPRankEnd; i++)
    for (int i = peerCPRankStart; i < peerCPRankEnd; i++)
    {
        for (int j = peerCPRankStart; j < peerCPRankEnd; j++)
        for (int j = peerTPRankStart; j < peerTPRankEnd; j++)
        {
            for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
            {
                int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i;
                // Rank formula: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
                int irank = (k * peerTPNum * peerCPNum) + (j * peerCPNum) + i;
                retRanks.push_back(irank);
            }
        }

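A worked sketch of the rank layout documented in the comment above (rank = ppRank * tpNum * cpNum + tpRank * cpNum + cpRank), with made-up parallelism sizes; it only demonstrates the arithmetic, not the surrounding TargetRanksInfoForDP logic.

// Decompose a flat rank back into PP/TP/CP coordinates.
#include <cassert>

int main()
{
    int const tpNum = 4, cpNum = 2;
    int const ppRank = 1, tpRank = 3, cpRank = 1;
    int const rank = ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank; // 8 + 6 + 1 = 15

    assert(rank / (tpNum * cpNum) == ppRank);            // 15 / 8 == 1
    assert((rank % (tpNum * cpNum)) / cpNum == tpRank);  // 7 / 2 == 3
    assert(rank % cpNum == cpRank);                      // 15 % 2 == 1
    return 0;
}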
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: NVIDIA TensorRT
# Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this material and related documentation without an express
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.

# MOONCAKE is not supported on Rocky8 for now
set(IS_ROCKY8 FALSE)
if(EXISTS "/etc/redhat-release")
  set(IS_ROCKY8 TRUE)
endif()

if(MOONCAKE_ROOT AND NOT IS_ROCKY8)
  find_library(TRANSFER_ENGINE_LIB transfer_engine ${MOONCAKE_ROOT}/lib)
  find_path(TRANSFER_ENGINE_INCLUDE_DIR transfer_engine_c.h
            ${MOONCAKE_ROOT}/include)

  message(STATUS "Find transfer engine results:")
  message(STATUS "  TRANSFER_ENGINE_LIB = ${TRANSFER_ENGINE_LIB}")
  message(
    STATUS "  TRANSFER_ENGINE_INCLUDE_DIR = ${TRANSFER_ENGINE_INCLUDE_DIR}")

  if(TRANSFER_ENGINE_LIB AND TRANSFER_ENGINE_INCLUDE_DIR)
    set(MOONCAKE_WRAPPER_TARGET "tensorrt_llm_mooncake_wrapper")

    add_library(${MOONCAKE_WRAPPER_TARGET} SHARED transferAgent.cpp)
    target_compile_options(${MOONCAKE_WRAPPER_TARGET} PRIVATE -Wno-error)

    target_include_directories(${MOONCAKE_WRAPPER_TARGET}
                               PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})

    target_link_libraries(${MOONCAKE_WRAPPER_TARGET}
                          PRIVATE ${TRANSFER_ENGINE_LIB} CUDA::cudart)

    # Export variables to parent scope for transfer_agent_binding
    set(TRANSFER_ENGINE_INCLUDE_DIR
        ${TRANSFER_ENGINE_INCLUDE_DIR}
        PARENT_SCOPE)
  endif()
endif()
@@ -0,0 +1,612 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/ipUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/transferAgent.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <algorithm>
#include <arpa/inet.h>
#include <chrono>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/file.h>
#include <sys/stat.h>
#include <thread>
#include <unistd.h>

namespace tensorrt_llm::executor::kv_cache
{

MooncakeTransferStatus::MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount)
    : mEngine{engine}
    , mBatchId{batchId}
    , mRequestCount{requestCount}
{
    TLLM_CHECK(mEngine);
}

TransferState MooncakeTransferStatus::wait(int64_t timeout_ms) const
{
    auto startTime = std::chrono::steady_clock::now();

    while (true)
    {
        if (mBatchFreed)
        {
            return TransferState::kSUCCESS;
        }

        bool has_failed = false;
        bool all_completed = true;

        for (size_t index = 0; index < mRequestCount; ++index)
        {
            transfer_status_t status;
            int rc = getTransferStatus(mEngine, mBatchId, index, &status);
            if (rc || status.status == STATUS_FAILED)
            {
                has_failed = true;
                if (rc)
                {
                    TLLM_LOG_ERROR(
                        "Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
                }
                else
                {
                    TLLM_LOG_ERROR(
                        "Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
                }
            }
            else if (status.status != STATUS_COMPLETED)
            {
                all_completed = false;
            }
        }

        // If any request failed, return failure
        if (has_failed)
        {
            return TransferState::kFAILURE;
        }

        // If all requests completed successfully
        if (all_completed)
        {
            freeBatchID(mEngine, mBatchId);
            mBatchFreed = true;
            TLLM_LOG_DEBUG("Batch ID %lu freed in wait()", mBatchId);
            syncSegmentCache(mEngine);
            return TransferState::kSUCCESS;
        }

        // If timeout_ms < 0, wait indefinitely
        if (timeout_ms < 0)
        {
            std::this_thread::yield();
            continue;
        }

        // Check if timeout has elapsed
        auto elapsed
            = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
                  .count();
        if (elapsed >= timeout_ms)
        {
            return TransferState::kIN_PROGRESS;
        }

        std::this_thread::yield();
    }
}
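A hedged sketch of the timeout contract implemented by wait() above: poll a completion predicate, block forever when the timeout is negative, otherwise give up once the deadline passes. The predicate stands in for getTransferStatus(); none of the names below are part of the Mooncake API.

// Generic polling-with-deadline helper, illustrative only.
#include <chrono>
#include <cstdint>
#include <functional>
#include <thread>

enum class PollResult { Done, TimedOut };

PollResult poll_with_timeout(std::function<bool()> const& completed, int64_t timeout_ms)
{
    auto const start = std::chrono::steady_clock::now();
    while (!completed())
    {
        if (timeout_ms >= 0)
        {
            auto const elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now() - start).count();
            if (elapsed >= timeout_ms)
            {
                return PollResult::TimedOut; // caller sees kIN_PROGRESS in the real code
            }
        }
        std::this_thread::yield();
    }
    return PollResult::Done;
}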
[[nodiscard]] bool MooncakeTransferStatus::isCompleted() const
{
    if (mBatchFreed)
    {
        return true;
    }

    bool has_failed = false;
    for (size_t index = 0; index < mRequestCount; ++index)
    {
        transfer_status_t status;
        int rc = getTransferStatus(mEngine, mBatchId, index, &status);
        if (rc || status.status == STATUS_FAILED)
        {
            has_failed = true;
            if (rc)
            {
                TLLM_LOG_ERROR(
                    "Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
            }
            else
            {
                TLLM_LOG_ERROR("Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
            }
        }
        else if (status.status == STATUS_PENDING || status.status == STATUS_WAITING)
        {
            TLLM_LOG_DEBUG("Transfer is pending for batch %lu, task %zu", mBatchId, index);
            return false;
        }
    }
    if (!has_failed)
    {
        // Each batchId has the batch size, and cannot process more requests
        // than the batch size. So, free the batch id here to workaround the issue
        // where the same batchId could be used to post multiple transfer.
        freeBatchID(mEngine, mBatchId);
        mBatchFreed = true;
        TLLM_LOG_DEBUG("Batch ID %lu freed, future calls will return true directly", mBatchId);
    }
    // Currently, we cannot distinguish between failed and completed from return value.
    TLLM_LOG_DEBUG("Transfer is completed for batch %lu", mBatchId);
    return true;
}
std::string const MooncakeBase64Helper::STANDARD_CHARS
    = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "abcdefghijklmnopqrstuvwxyz"
      "0123456789+/";

std::string MooncakeBase64Helper::encode(std::vector<uint8_t> const& data)
{
    return encodeInternal(data, STANDARD_CHARS);
}

std::string MooncakeBase64Helper::encode(std::string const& data)
{
    std::vector<uint8_t> vec(data.begin(), data.end());
    return encode(vec);
}

std::vector<uint8_t> MooncakeBase64Helper::decode(std::string const& encoded)
{
    return decodeInternal(encoded, STANDARD_CHARS);
}

std::string MooncakeBase64Helper::decodeToString(std::string const& encoded)
{
    auto vec = decode(encoded);
    return std::string(vec.begin(), vec.end());
}

std::string MooncakeBase64Helper::encodeInternal(std::vector<uint8_t> const& data, std::string const& chars)
{
    std::string encoded;
    size_t i = 0;
    size_t j = 0;
    std::array<uint8_t, 3> charArray3{};
    std::array<uint8_t, 4> charArray4{};
    size_t dataLen = data.size();
    uint8_t const* bytes = data.data();

    while (dataLen--)
    {
        charArray3[i++] = *(bytes++);
        if (i == 3)
        {
            charArray4[0] = (charArray3[0] & 0xfc) >> 2;
            charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
            charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
            charArray4[3] = charArray3[2] & 0x3f;

            for (i = 0; i < 4; i++)
            {
                encoded += chars[charArray4[i]];
            }
            i = 0;
        }
    }

    if (i > 0)
    {
        for (j = i; j < 3; j++)
        {
            charArray3[j] = '\0';
        }

        charArray4[0] = (charArray3[0] & 0xfc) >> 2;
        charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
        charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
        charArray4[3] = charArray3[2] & 0x3f;

        for (j = 0; j < i + 1; j++)
        {
            encoded += chars[charArray4[j]];
        }

        while (i++ < 3)
        {
            encoded += '=';
        }
    }

    return encoded;
}
std::vector<uint8_t> MooncakeBase64Helper::decodeInternal(std::string const& encoded, std::string const& chars)
{
    size_t encodedLen = encoded.size();
    size_t i = 0;
    size_t j = 0;
    size_t in_ = 0;
    std::array<uint8_t, 3> charArray3{};
    std::array<uint8_t, 4> charArray4{};
    std::vector<uint8_t> decoded;

    std::string cleanEncoded;
    for (char c : encoded)
    {
        if (!isWhitespace(c))
        {
            cleanEncoded += c;
        }
    }

    encodedLen = cleanEncoded.size();

    while (encodedLen-- && cleanEncoded[in_] != '=' && isBase64(cleanEncoded[in_], chars))
    {
        charArray4[i++] = cleanEncoded[in_];
        in_++;
        if (i == 4)
        {
            for (i = 0; i < 4; i++)
            {
                charArray4[i] = chars.find(charArray4[i]);
            }

            charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
            charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
            charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];

            for (i = 0; i < 3; i++)
            {
                decoded.push_back(charArray3[i]);
            }
            i = 0;
        }
    }

    if (i > 0)
    {
        for (j = i; j < 4; j++)
        {
            charArray4[j] = 0;
        }

        for (j = 0; j < 4; j++)
        {
            charArray4[j] = chars.find(charArray4[j]);
        }

        charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
        charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
        charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];

        for (j = 0; j < i - 1; j++)
        {
            decoded.push_back(charArray3[j]);
        }
    }

    return decoded;
}

bool MooncakeBase64Helper::isBase64(uint8_t c, std::string const& chars)
{
    return (isalnum(c) || (c == chars[62]) || (c == chars[63]));
}

bool MooncakeBase64Helper::isWhitespace(uint8_t c)
{
    return (c == ' ' || c == '\n' || c == '\r' || c == '\t');
}
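A hedged round-trip check for a base64 helper with the interface shown above ("Man" encodes to "TWFu" in standard base64). MooncakeBase64Helper itself is assumed to be declared in the corresponding header; only the expected input/output values are standard base64 facts.

// Usage sketch: encode then decode and compare against the original string.
#include <cassert>
#include <string>

void base64_roundtrip_example()
{
    std::string const original = "Man";
    std::string const encoded = MooncakeBase64Helper::encode(original);
    assert(encoded == "TWFu");
    assert(MooncakeBase64Helper::decodeToString(encoded) == original);
}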
MooncakeTransferAgent::MooncakeTransferAgent(BaseAgentConfig const& config)
{
    mLocalAgentName = config.mName;
    std::string segmentName = "127.0.0.1";

    if (getenv("TLLM_MOONCAKE_IP_ADDR"))
    {
        segmentName = std::string(getenv("TLLM_MOONCAKE_IP_ADDR"));
    }
    else
    {
        auto ip = common::getLocalIp(common::getEnvMooncakeInterface(), mpi::MpiComm::session().getRank());
        if (!ip.empty())
            segmentName = ip;
    }

    mEngine = createTransferEngine("P2PHANDSHAKE", segmentName.c_str(), "", 0, true);
}

void MooncakeTransferAgent::registerMemory(RegisterDescs const& descs)
{
    TLLM_LOG_DEBUG("MooncakeTransferAgent::registerMemory");

    std::lock_guard<std::mutex> lock(mMutex);
    for (auto const& desc : descs.getDescs())
    {
        auto it = mMemRegInfo.find(desc.getAddr());
        if (it != mMemRegInfo.end())
        {
            it->second->addRef();
            continue;
        }

        int err = registerLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()), desc.getLen(), "*", 1);

        TLLM_CHECK_WITH_INFO(err == 0, "registerLocalMemory failed, addr: %p, len: %lu",
            reinterpret_cast<void*>(desc.getAddr()), desc.getLen());

        auto mooncakeDesc = std::make_shared<MooncakeMemoryDesc>(desc);
        mMemRegInfo[desc.getAddr()] = std::move(mooncakeDesc);
    }
}

void MooncakeTransferAgent::deregisterMemory(RegisterDescs const& descs)
{
    TLLM_LOG_DEBUG("MooncakeTransferAgent::deregisterMemory");

    std::lock_guard<std::mutex> lock(mMutex);
    for (auto const& desc : descs.getDescs())
    {
        auto it = mMemRegInfo.find(desc.getAddr());
        if (it != mMemRegInfo.end())
        {
            auto const& mooncakeDesc = it->second;
            mooncakeDesc->releaseRef();
            if (mooncakeDesc->getRefCount())
                continue;

            int err = unregisterLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()));

            TLLM_CHECK_WITH_INFO(
                err == 0, "unregisterLocalMemory failed, addr: %p", reinterpret_cast<void*>(desc.getAddr()));

            mMemRegInfo.erase(desc.getAddr());
        }
    }
}
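A hedged sketch of the refcounted registration pattern used by registerMemory/deregisterMemory above: a region is registered with the engine only on its first use and unregistered only when the last reference is released. The class and member names below are illustrative stand-ins, not the TensorRT-LLM types.

// Minimal refcounting cache; acquire/release report when real (de)registration is needed.
#include <cstdint>
#include <unordered_map>

class RegistrationCache
{
public:
    // Returns true when the caller must actually register the region with the engine.
    bool acquire(uintptr_t addr) { return mRefCounts[addr]++ == 0; }

    // Returns true when the caller must actually unregister the region.
    bool release(uintptr_t addr)
    {
        auto it = mRefCounts.find(addr);
        if (it == mRefCounts.end() || --it->second > 0)
        {
            return false;
        }
        mRefCounts.erase(it);
        return true;
    }

private:
    std::unordered_map<uintptr_t, int> mRefCounts;
};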
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc)
{
    TLLM_LOG_DEBUG("MooncakeTransferAgent::loadRemoteAgent");

    // Do the same thing as loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
    loadRemoteAgent(name, std::move(agentDesc.getBackendAgentDesc()));
}

void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
{
    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
        "MooncakeTransferAgent::loadRemoteAgent loadRemoteAgent to %s remoteagent name: %s", connectionInfo.c_str(),
        name.c_str());

    std::lock_guard<std::mutex> lock(mMutex);
    auto segmentId = openSegment(mEngine, connectionInfo.c_str());

    TLLM_CHECK_WITH_INFO(
        segmentId >= 0, "loadRemoteAgent openSegment failed, connectionInfo: %s", connectionInfo.c_str());

    mConnectedAgents[name].segmentId = segmentId;
}

void MooncakeTransferAgent::invalidateRemoteAgent(std::string const& name)
{
    TLLM_LOG_DEBUG("MooncakeTransferAgent::invalidateRemoteAgent");
}

AgentDesc MooncakeTransferAgent::getLocalAgentDesc()
{
    TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalAgentDesc");

    // Using connection info as agent desc
    static size_t const kBufLen = 64;
    char connectionInfo[kBufLen];

    int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);

    TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalIpAndPort failed");

    return AgentDesc{std::string(connectionInfo)};
}

ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
{
    TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalConnectionInfo");

    static size_t const kBufLen = 64;
    char connectionInfo[kBufLen];

    int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);

    TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalConnectionInfo::getLocalIpAndPort failed");

    return std::string(connectionInfo);
}
[[nodiscard]] std::unique_ptr<TransferStatus> MooncakeTransferAgent::submitTransferRequests(
|
||||
TransferRequest const& request)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::submitTransferRequests");
|
||||
|
||||
bool hasNotif = false;
|
||||
std::string syncMessage;
|
||||
|
||||
if (request.getSyncMessage().has_value())
|
||||
{
|
||||
hasNotif = true;
|
||||
syncMessage = request.getSyncMessage().value();
|
||||
}
|
||||
|
||||
static size_t const kMaxRequestCount = 1024;
|
||||
uint64_t batchId = allocateBatchID(mEngine, kMaxRequestCount);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchId != INVALID_BATCH, "allocateBatchID failed");
|
||||
|
||||
int segmentId;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
std::string remoteName = request.getRemoteName();
|
||||
|
||||
auto it = mConnectedAgents.find(remoteName);
|
||||
if (it == mConnectedAgents.end())
|
||||
{
|
||||
std::string error = "Remote agent " + remoteName + "not found";
|
||||
TLLM_THROW(error);
|
||||
}
|
||||
|
||||
auto const& agentInfo = it->second;
|
||||
segmentId = agentInfo.segmentId;
|
||||
}
|
||||
|
||||
auto localDescs = request.getSrcDescs().getDescs();
|
||||
auto remoteDescs = request.getDstDescs().getDescs();
|
||||
|
||||
TLLM_CHECK_WITH_INFO(localDescs.size() == remoteDescs.size(), "Number of local and remote memory descriptors must match");
|
||||
|
||||
size_t requestCount = localDescs.size();
|
||||
std::vector<transfer_request_t> transferRequests(requestCount);
|
||||
|
||||
for (size_t index = 0; index < requestCount; ++index)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
localDescs[index].getLen() == remoteDescs[index].getLen(), "Length of local and remote memory regions must match");
|
||||
|
||||
transferRequests[index].opcode = (request.getOp() == TransferOp::kREAD) ? OPCODE_READ : OPCODE_WRITE;
|
||||
transferRequests[index].source = reinterpret_cast<void*>(localDescs[index].getAddr());
|
||||
transferRequests[index].target_offset = remoteDescs[index].getAddr();
|
||||
transferRequests[index].length = localDescs[index].getLen();
|
||||
transferRequests[index].target_id = segmentId;
|
||||
}
|
||||
|
||||
int rc = 0;
|
||||
if (hasNotif)
|
||||
{
|
||||
notify_msg_t notifyMsg;
|
||||
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
|
||||
notifyMsg.msg = const_cast<char*>(syncMessage.c_str());
|
||||
rc = submitTransferWithNotify(mEngine, batchId, transferRequests.data(), requestCount, notifyMsg);
|
||||
}
|
||||
else
|
||||
{
|
||||
rc = submitTransfer(mEngine, batchId, transferRequests.data(), requestCount);
|
||||
}
|
||||
|
||||
TLLM_CHECK_WITH_INFO(rc == 0, "submitTransfer failed with status: %d", rc);
|
||||
|
||||
return std::make_unique<MooncakeTransferStatus>(mEngine, batchId, requestCount);
|
||||
}
|
||||
|
||||
void MooncakeTransferAgent::notifySyncMessage(std::string const& name, SyncMessage const& syncMessage)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage");
|
||||
int segmentId;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
auto it = mConnectedAgents.find(name);
|
||||
|
||||
if (it == mConnectedAgents.end())
|
||||
{
|
||||
TLLM_LOG_WARNING("Remote agent %s not found", name.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
auto const& agentInfo = it->second;
|
||||
segmentId = agentInfo.segmentId;
|
||||
}
|
||||
|
||||
notify_msg_t notifyMsg;
|
||||
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
|
||||
std::string encoded = MooncakeBase64Helper::encode(syncMessage);
|
||||
notifyMsg.msg = const_cast<char*>(encoded.c_str());
|
||||
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage notifyMsg.name: %s, notifyMsg.msg: %s", notifyMsg.name,
|
||||
notifyMsg.msg);
|
||||
|
||||
int ret = genNotifyInEngine(mEngine, segmentId, notifyMsg);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(ret == 0, "genNotifyInEngine failed with status: %d", ret);
|
||||
}
|
||||
|
||||
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> MooncakeTransferAgent::getNotifiedSyncMessages()
|
||||
{
|
||||
std::unordered_map<std::string, std::vector<SyncMessage>> notifs;
|
||||
int size = 0;
|
||||
|
||||
notify_msg_t* notifyMsgs = getNotifsFromEngine(mEngine, &size);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(size >= 0, "getNotifsFromEngine returned negative size: %d", size);
|
||||
|
||||
for (int i = 0; i < size; i++)
|
||||
{
|
||||
if (notifyMsgs[i].msg == nullptr)
|
||||
{
|
||||
TLLM_LOG_WARNING("Message pointer is null for: %s", notifyMsgs[i].name);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string decoded = MooncakeBase64Helper::decodeToString(notifyMsgs[i].msg);
|
||||
notifs[notifyMsgs[i].name].emplace_back(std::move(decoded));
|
||||
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::getNotifiedSyncMessages getNotifsFromEngine: %s, %s", notifyMsgs[i].name,
|
||||
notifyMsgs[i].msg);
|
||||
}
|
||||
|
||||
freeNotifsMsgBuf(notifyMsgs, size);
|
||||
return notifs;
|
||||
}
|
||||
|
||||
bool MooncakeTransferAgent::checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::checkRemoteDescs");
|
||||
return true;
|
||||
}
|
||||
|
||||
MooncakeTransferAgent::~MooncakeTransferAgent()
|
||||
{
|
||||
destroyTransferEngine(mEngine);
|
||||
TLLM_LOG_DEBUG("MooncakeTransferAgent::~MooncakeTransferAgent");
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
|
||||
#endif
|
||||
|
||||
extern "C"
|
||||
{
|
||||
std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config)
|
||||
{
|
||||
TLLM_CHECK(config);
|
||||
return std::make_unique<MooncakeTransferAgent>(*config);
|
||||
}
|
||||
}

#if defined(__clang__)
#pragma clang diagnostic pop
#endif
|
||||
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
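A rough sketch of how this Mooncake agent can be driven from Python once the transfer-agent binding defined further below is built; the module name follows the binding target, while the peer name "ctx0" and the out-of-band exchange helper are assumptions made only for illustration.

# Illustrative sketch: Mooncake agent handshake and sync-message round trip.
import os
import tensorrt_llm_transfer_agent_binding as tab

# Optional: pin the IP the Mooncake engine binds to; read by the constructor above.
os.environ.setdefault("TLLM_MOONCAKE_IP_ADDR", "127.0.0.1")

cfg = tab.BaseAgentConfig(name="gen0")
agent = tab.make_transfer_agent("mooncake", cfg)

# "ip:port" string from getLocalIpAndPort(); ship it to the peer out of band.
local_info = agent.get_local_connection_info()
peer_info = exchange_with_peer(local_info)  # hypothetical out-of-band exchange
agent.load_remote_agent_by_connection("ctx0", peer_info)

# Sync messages are base64-encoded on the wire by MooncakeBase64Helper; the
# encoding is transparent to the caller.
agent.notify_sync_message("ctx0", "kv-cache-ready")
incoming = agent.get_notified_sync_messages()  # {peer name: [decoded messages]}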
@ -0,0 +1,165 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorrt_llm/executor/transferAgent.h"
|
||||
#include "transfer_engine_c.h"
|
||||
|
||||
namespace tensorrt_llm::executor::kv_cache
|
||||
{
|
||||
|
||||
class MooncakeTransferStatus final : public TransferStatus
|
||||
{
|
||||
public:
|
||||
MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount);
|
||||
|
||||
[[nodiscard]] bool isCompleted() const override;
|
||||
|
||||
TransferState wait(int64_t timeout_ms = -1) const override;
|
||||
|
||||
private:
|
||||
transfer_engine_t mEngine;
|
||||
uint64_t mBatchId;
|
||||
size_t mRequestCount;
|
||||
mutable bool mBatchFreed = false;
|
||||
};
|
||||
|
||||
class MooncakeMemoryDesc
|
||||
{
|
||||
public:
|
||||
MooncakeMemoryDesc(MemoryDesc desc)
|
||||
: mDesc{std::move(desc)}
|
||||
, mRefCnt{0}
|
||||
{
|
||||
}
|
||||
|
||||
MooncakeMemoryDesc(MooncakeMemoryDesc const& other)
|
||||
: mDesc{other.mDesc}
|
||||
, mRefCnt{0}
|
||||
{
|
||||
}
|
||||
|
||||
MooncakeMemoryDesc& operator=(MooncakeMemoryDesc const&) = delete;
|
||||
|
||||
~MooncakeMemoryDesc() = default;
|
||||
|
||||
void addRef() noexcept
|
||||
{
|
||||
++mRefCnt;
|
||||
}
|
||||
|
||||
int releaseRef() noexcept
|
||||
{
|
||||
return --mRefCnt;
|
||||
}
|
||||
|
||||
int getRefCount() const noexcept
|
||||
{
|
||||
return mRefCnt;
|
||||
}
|
||||
|
||||
MemoryDesc const& getDesc() const noexcept
|
||||
{
|
||||
return mDesc;
|
||||
}
|
||||
|
||||
private:
|
||||
MemoryDesc mDesc;
|
||||
int mRefCnt;
|
||||
};
|
||||
|
||||
class MooncakeBase64Helper
|
||||
{
|
||||
public:
|
||||
static std::string encode(std::vector<uint8_t> const& data);
|
||||
static std::string encode(std::string const& data);
|
||||
|
||||
static std::vector<uint8_t> decode(std::string const& encoded);
|
||||
static std::string decodeToString(std::string const& encoded);
|
||||
|
||||
private:
|
||||
static const std::string STANDARD_CHARS;
|
||||
|
||||
static std::string encodeInternal(std::vector<uint8_t> const& data, std::string const& chars);
|
||||
static std::vector<uint8_t> decodeInternal(std::string const& encoded, std::string const& chars);
|
||||
|
||||
static inline bool isBase64(uint8_t c, std::string const& chars);
|
||||
static inline bool isWhitespace(uint8_t c);
|
||||
};
|
||||
|
||||
class MooncakeTransferAgent final : public BaseTransferAgent
|
||||
{
|
||||
public:
|
||||
MooncakeTransferAgent(BaseAgentConfig const& config);
|
||||
~MooncakeTransferAgent();
|
||||
|
||||
void registerMemory(RegisterDescs const& descs) override;
|
||||
|
||||
void deregisterMemory(RegisterDescs const& descs) override;
|
||||
|
||||
void loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc) override;
|
||||
|
||||
void loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) override;
|
||||
|
||||
void invalidateRemoteAgent(std::string const& name) override;
|
||||
|
||||
AgentDesc getLocalAgentDesc() override;
|
||||
|
||||
ConnectionInfoType getLocalConnectionInfo() override;
|
||||
|
||||
[[nodiscard]] std::unique_ptr<TransferStatus> submitTransferRequests(TransferRequest const& request) override;
|
||||
|
||||
void notifySyncMessage(std::string const& name, SyncMessage const& syncMessage) override;
|
||||
|
||||
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> getNotifiedSyncMessages() override;
|
||||
|
||||
bool checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs) override;
|
||||
|
||||
private:
|
||||
struct AgentInfo
|
||||
{
|
||||
int segmentId;
|
||||
};
|
||||
|
||||
mutable std::mutex mMutex;
|
||||
transfer_engine_t mEngine;
|
||||
std::unordered_map<uintptr_t, std::shared_ptr<MooncakeMemoryDesc>> mMemRegInfo;
|
||||
std::unordered_map<std::string, AgentInfo> mConnectedAgents;
|
||||
std::string mLocalAgentName;
|
||||
};
|
||||
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
|
||||
#endif
|
||||
|
||||
extern "C"
|
||||
{
|
||||
[[nodiscard]] std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config);
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
} // namespace tensorrt_llm::executor::kv_cache
|
||||
@ -13,6 +13,9 @@
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
|
||||
# ============================================================================
|
||||
# NIXL Wrapper Library
|
||||
# ============================================================================
|
||||
if(NIXL_ROOT)
|
||||
find_package(NIXL REQUIRED)
|
||||
# Check if all required packages were found
|
||||
@ -30,6 +33,8 @@ if(NIXL_ROOT)
|
||||
|
||||
# Add include directories
|
||||
target_include_directories(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
|
||||
target_include_directories(${NIXL_WRAPPER_TARGET}
|
||||
PRIVATE ${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
# Link against all NIXL libraries
|
||||
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
|
||||
@ -37,4 +42,85 @@ if(NIXL_ROOT)
|
||||
# Link against CUDA
|
||||
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE CUDA::cudart)
|
||||
|
||||
set(NIXL_ENABLED TRUE)
|
||||
else()
|
||||
set(NIXL_ENABLED FALSE)
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# Check if Mooncake wrapper is available (built in mooncake_utils)
|
||||
# ============================================================================
|
||||
if(MOONCAKE_ROOT AND TARGET tensorrt_llm_mooncake_wrapper)
|
||||
set(MOONCAKE_ENABLED TRUE)
|
||||
else()
|
||||
set(MOONCAKE_ENABLED FALSE)
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# TensorRT-LLM Transfer Agent Binding Python Module. Built if either NIXL or
# Mooncake is enabled
|
||||
# ============================================================================
|
||||
if(NIXL_ENABLED OR MOONCAKE_ENABLED)
|
||||
set(TRANSFER_AGENT_BINDING_TARGET "tensorrt_llm_transfer_agent_binding")
|
||||
|
||||
# Collect binding source files
|
||||
set(AGENT_BINDING_SOURCES "")
|
||||
if(BINDING_TYPE STREQUAL "pybind")
|
||||
list(APPEND AGENT_BINDING_SOURCES agentBindingsPybind.cpp)
|
||||
else()
|
||||
list(APPEND AGENT_BINDING_SOURCES agentBindingsNanobind.cpp)
|
||||
endif()
|
||||
|
||||
if(BINDING_TYPE STREQUAL "pybind")
|
||||
# Use pybind11 (already fetched via FetchContent)
|
||||
pybind11_add_module(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
${AGENT_BINDING_SOURCES})
|
||||
message(STATUS "Building tensorrt_llm_transfer_agent_binding with pybind11")
|
||||
else()
|
||||
# Default to nanobind (already fetched via FetchContent)
|
||||
nanobind_add_module(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
${AGENT_BINDING_SOURCES})
|
||||
message(STATUS "Building tensorrt_llm_transfer_agent_binding with nanobind")
|
||||
endif()
|
||||
|
||||
target_compile_options(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE -Wno-error)
|
||||
|
||||
# Add common include directories
|
||||
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
# Conditionally add NIXL support
|
||||
if(NIXL_ENABLED)
|
||||
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ENABLE_NIXL)
|
||||
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE NIXL::nixl)
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ${NIXL_WRAPPER_TARGET})
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE NIXL::nixl)
|
||||
message(STATUS "Transfer agent binding: NIXL support enabled")
|
||||
endif()
|
||||
|
||||
# Conditionally add Mooncake support
|
||||
if(MOONCAKE_ENABLED)
|
||||
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ENABLE_MOONCAKE)
|
||||
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE tensorrt_llm_mooncake_wrapper)
|
||||
message(STATUS "Transfer agent binding: Mooncake support enabled")
|
||||
endif()
|
||||
|
||||
# Common dependencies
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE CUDA::cudart)
|
||||
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PRIVATE ${SHARED_TARGET})
|
||||
|
||||
# Set RPATH for the module to find wrapper libraries
|
||||
set_target_properties(
|
||||
${TRANSFER_AGENT_BINDING_TARGET}
|
||||
PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl"
|
||||
INSTALL_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl")
|
||||
|
||||
endif()
|
||||
|
||||
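Whether each backend actually made it into the binding can be checked at run time through the module attributes set in the binding sources below; a minimal sketch, assuming the module imports under the target name used above:

# Minimal check of which transfer backends were compiled into the binding.
import tensorrt_llm_transfer_agent_binding as tab

print("NIXL enabled:", tab.NIXL_ENABLED)
print("Mooncake enabled:", tab.MOONCAKE_ENABLED)

backend = "nixl" if tab.NIXL_ENABLED else "mooncake"
agent = tab.make_transfer_agent(backend, tab.BaseAgentConfig(name="worker0"))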
@ -0,0 +1,239 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/executor/transferAgent.h"
|
||||
|
||||
#ifdef ENABLE_NIXL
|
||||
#include "transferAgent.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
#include "../mooncake_utils/transferAgent.h"
|
||||
#endif
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
namespace nb = nanobind;
|
||||
namespace kvc = tensorrt_llm::executor::kv_cache;
|
||||
|
||||
NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
|
||||
{
|
||||
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (nanobind)";
|
||||
|
||||
// MemoryType enum
|
||||
nb::enum_<kvc::MemoryType>(m, "MemoryType")
|
||||
.value("DRAM", kvc::MemoryType::kDRAM)
|
||||
.value("VRAM", kvc::MemoryType::kVRAM)
|
||||
.value("BLK", kvc::MemoryType::kBLK)
|
||||
.value("OBJ", kvc::MemoryType::kOBJ)
|
||||
.value("FILE", kvc::MemoryType::kFILE);
|
||||
|
||||
// TransferOp enum
|
||||
nb::enum_<kvc::TransferOp>(m, "TransferOp")
|
||||
.value("READ", kvc::TransferOp::kREAD)
|
||||
.value("WRITE", kvc::TransferOp::kWRITE);
|
||||
|
||||
// TransferState enum
|
||||
nb::enum_<kvc::TransferState>(m, "TransferState")
|
||||
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
|
||||
.value("SUCCESS", kvc::TransferState::kSUCCESS)
|
||||
.value("FAILURE", kvc::TransferState::kFAILURE);
|
||||
|
||||
// MemoryDesc class
|
||||
nb::class_<kvc::MemoryDesc>(m, "MemoryDesc")
|
||||
.def(nb::init<uintptr_t, size_t, uint32_t>(), nb::arg("addr"), nb::arg("len"), nb::arg("device_id"))
|
||||
.def_prop_ro("addr", &kvc::MemoryDesc::getAddr)
|
||||
.def_prop_ro("len", &kvc::MemoryDesc::getLen)
|
||||
.def_prop_ro("device_id", &kvc::MemoryDesc::getDeviceId);
|
||||
|
||||
// MemoryDescs class
|
||||
nb::class_<kvc::MemoryDescs>(m, "MemoryDescs")
|
||||
.def(nb::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), nb::arg("type"), nb::arg("descs"))
|
||||
.def_prop_ro("type", &kvc::MemoryDescs::getType)
|
||||
.def_prop_ro("descs", &kvc::MemoryDescs::getDescs);
|
||||
|
||||
// AgentDesc class
|
||||
nb::class_<kvc::AgentDesc>(m, "AgentDesc")
|
||||
.def(
|
||||
"__init__",
|
||||
[](kvc::AgentDesc* self, nb::bytes data)
|
||||
{
|
||||
std::string str(data.c_str(), data.size());
|
||||
new (self) kvc::AgentDesc{std::move(str)};
|
||||
},
|
||||
nb::arg("backend_agent_desc"))
|
||||
.def(nb::init<std::string>(), nb::arg("backend_agent_desc"))
|
||||
.def_prop_ro("backend_agent_desc",
|
||||
[](kvc::AgentDesc const& self)
|
||||
{
|
||||
auto const& desc = self.getBackendAgentDesc();
|
||||
return nb::bytes(desc.data(), desc.size());
|
||||
});
|
||||
|
||||
// TransferRequest class
|
||||
nb::class_<kvc::TransferRequest>(m, "TransferRequest")
|
||||
.def(nb::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
|
||||
std::optional<kvc::SyncMessage>>(),
|
||||
nb::arg("op"), nb::arg("src_descs"), nb::arg("dst_descs"), nb::arg("remote_name"),
|
||||
nb::arg("sync_message") = std::nullopt)
|
||||
.def_prop_ro("op", &kvc::TransferRequest::getOp)
|
||||
.def_prop_ro("src_descs", &kvc::TransferRequest::getSrcDescs)
|
||||
.def_prop_ro("dst_descs", &kvc::TransferRequest::getDstDescs)
|
||||
.def_prop_ro("remote_name", &kvc::TransferRequest::getRemoteName)
|
||||
.def_prop_ro("sync_message", &kvc::TransferRequest::getSyncMessage);
|
||||
|
||||
// TransferStatus base class
|
||||
nb::class_<kvc::TransferStatus>(m, "TransferStatus")
|
||||
.def("is_completed", &kvc::TransferStatus::isCompleted)
|
||||
.def("wait", &kvc::TransferStatus::wait, nb::arg("timeout_ms") = -1);
|
||||
|
||||
// BaseAgentConfig struct
|
||||
nb::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
|
||||
.def(nb::init<>())
|
||||
.def(
|
||||
"__init__",
|
||||
[](kvc::BaseAgentConfig* self, std::string name, bool use_prog_thread, bool multi_thread,
|
||||
bool use_listen_thread, unsigned int num_workers) {
|
||||
new (self) kvc::BaseAgentConfig{
|
||||
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
|
||||
},
|
||||
nb::arg("name"), nb::arg("use_prog_thread") = true, nb::arg("multi_thread") = false,
|
||||
nb::arg("use_listen_thread") = false, nb::arg("num_workers") = 1)
|
||||
.def_rw("name", &kvc::BaseAgentConfig::mName)
|
||||
.def_rw("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
|
||||
.def_rw("multi_thread", &kvc::BaseAgentConfig::multiThread)
|
||||
.def_rw("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
|
||||
.def_rw("num_workers", &kvc::BaseAgentConfig::numWorkers);
|
||||
|
||||
// BaseTransferAgent class (abstract base)
|
||||
nb::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
|
||||
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, nb::arg("descs"))
|
||||
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, nb::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::BaseTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
|
||||
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, nb::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
nb::arg("request"), nb::rv_policy::take_ownership)
|
||||
.def(
|
||||
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
|
||||
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
|
||||
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
|
||||
|
||||
#ifdef ENABLE_NIXL
|
||||
// NixlTransferStatus class - release GIL for blocking operations
|
||||
nb::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
|
||||
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("wait", &kvc::NixlTransferStatus::wait, nb::arg("timeout_ms") = -1,
|
||||
nb::call_guard<nb::gil_scoped_release>());
|
||||
|
||||
// NixlTransferAgent class
|
||||
nb::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
|
||||
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
|
||||
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, nb::arg("descs"))
|
||||
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, nb::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::NixlTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
|
||||
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
|
||||
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, nb::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
|
||||
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
// MooncakeTransferStatus class - release GIL for blocking operations
|
||||
nb::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
|
||||
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("wait", &kvc::MooncakeTransferStatus::wait, nb::arg("timeout_ms") = -1,
|
||||
nb::call_guard<nb::gil_scoped_release>());
|
||||
|
||||
// MooncakeTransferAgent class
|
||||
nb::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
|
||||
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
|
||||
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, nb::arg("descs"))
|
||||
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, nb::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::MooncakeTransferAgent::loadRemoteAgent),
|
||||
nb::arg("name"), nb::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
|
||||
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
|
||||
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, nb::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, nb::arg("name"),
|
||||
nb::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
|
||||
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, nb::arg("name"),
|
||||
nb::arg("memory_descs"));
|
||||
#endif
|
||||
|
||||
// Factory function to create transfer agent by backend name (uses dynamic loading)
|
||||
m.def(
|
||||
"make_transfer_agent",
|
||||
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
|
||||
{ return kvc::makeTransferAgent(backend, &config).release(); },
|
||||
nb::arg("backend"), nb::arg("config"), nb::rv_policy::take_ownership,
|
||||
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
|
||||
|
||||
// Expose which backends are available
|
||||
#ifdef ENABLE_NIXL
|
||||
m.attr("NIXL_ENABLED") = true;
|
||||
#else
|
||||
m.attr("NIXL_ENABLED") = false;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
m.attr("MOONCAKE_ENABLED") = true;
|
||||
#else
|
||||
m.attr("MOONCAKE_ENABLED") = false;
|
||||
#endif
|
||||
}
|
||||
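A minimal end-to-end sketch of the surface defined above, assuming RegisterDescs and TransferDescs alias MemoryDescs (as the binding signatures suggest), that the peer's name, connection info, and registered address are obtained out of band, and that torch is used only to provide a device buffer:

# Sketch: write a local CUDA buffer into a peer's registered buffer.
import torch
import tensorrt_llm_transfer_agent_binding as tab

cfg = tab.BaseAgentConfig(name="ctx0", use_prog_thread=True)
agent = tab.make_transfer_agent("nixl" if tab.NIXL_ENABLED else "mooncake", cfg)

# Register a local CUDA buffer (single-GPU assumption: device_id=0).
buf = torch.zeros(1 << 20, dtype=torch.uint8, device="cuda")
local_desc = tab.MemoryDesc(addr=buf.data_ptr(), len=buf.numel(), device_id=0)
agent.register_memory(tab.MemoryDescs(tab.MemoryType.VRAM, [local_desc]))

# Connect to the peer; peer_connection_info is a placeholder obtained out of band.
agent.load_remote_agent_by_connection("gen0", peer_connection_info)

# peer_addr is a placeholder for memory the peer has already registered.
remote_desc = tab.MemoryDesc(addr=peer_addr, len=buf.numel(), device_id=0)
req = tab.TransferRequest(
    op=tab.TransferOp.WRITE,
    src_descs=tab.MemoryDescs(tab.MemoryType.VRAM, [local_desc]),
    dst_descs=tab.MemoryDescs(tab.MemoryType.VRAM, [remote_desc]),
    remote_name="gen0",
    sync_message="block-0-done",
)
status = agent.submit_transfer_requests(req)
assert status.wait() == tab.TransferState.SUCCESS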
@ -0,0 +1,234 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/executor/transferAgent.h"
|
||||
|
||||
#ifdef ENABLE_NIXL
|
||||
#include "transferAgent.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
#include "../mooncake_utils/transferAgent.h"
|
||||
#endif
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace kvc = tensorrt_llm::executor::kv_cache;
|
||||
|
||||
PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
|
||||
{
|
||||
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (pybind11)";
|
||||
|
||||
// MemoryType enum
|
||||
py::enum_<kvc::MemoryType>(m, "MemoryType")
|
||||
.value("DRAM", kvc::MemoryType::kDRAM)
|
||||
.value("VRAM", kvc::MemoryType::kVRAM)
|
||||
.value("BLK", kvc::MemoryType::kBLK)
|
||||
.value("OBJ", kvc::MemoryType::kOBJ)
|
||||
.value("FILE", kvc::MemoryType::kFILE);
|
||||
|
||||
// TransferOp enum
|
||||
py::enum_<kvc::TransferOp>(m, "TransferOp")
|
||||
.value("READ", kvc::TransferOp::kREAD)
|
||||
.value("WRITE", kvc::TransferOp::kWRITE);
|
||||
|
||||
// TransferState enum
|
||||
py::enum_<kvc::TransferState>(m, "TransferState")
|
||||
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
|
||||
.value("SUCCESS", kvc::TransferState::kSUCCESS)
|
||||
.value("FAILURE", kvc::TransferState::kFAILURE);
|
||||
|
||||
// MemoryDesc class
|
||||
py::class_<kvc::MemoryDesc>(m, "MemoryDesc")
|
||||
.def(py::init<uintptr_t, size_t, uint32_t>(), py::arg("addr"), py::arg("len"), py::arg("device_id"))
|
||||
.def_property_readonly("addr", &kvc::MemoryDesc::getAddr)
|
||||
.def_property_readonly("len", &kvc::MemoryDesc::getLen)
|
||||
.def_property_readonly("device_id", &kvc::MemoryDesc::getDeviceId);
|
||||
|
||||
// MemoryDescs class
|
||||
py::class_<kvc::MemoryDescs>(m, "MemoryDescs")
|
||||
.def(py::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), py::arg("type"), py::arg("descs"))
|
||||
.def_property_readonly("type", &kvc::MemoryDescs::getType)
|
||||
.def_property_readonly("descs", &kvc::MemoryDescs::getDescs);
|
||||
|
||||
// AgentDesc class
|
||||
py::class_<kvc::AgentDesc>(m, "AgentDesc")
|
||||
.def(py::init(
|
||||
[](py::bytes data)
|
||||
{
|
||||
std::string str(PyBytes_AsString(data.ptr()), PyBytes_Size(data.ptr()));
|
||||
return kvc::AgentDesc{std::move(str)};
|
||||
}),
|
||||
py::arg("backend_agent_desc"))
|
||||
.def(py::init<std::string>(), py::arg("backend_agent_desc"))
|
||||
.def_property_readonly("backend_agent_desc",
|
||||
[](kvc::AgentDesc const& self)
|
||||
{
|
||||
auto const& desc = self.getBackendAgentDesc();
|
||||
return py::bytes(desc.data(), desc.size());
|
||||
});
|
||||
|
||||
// TransferRequest class
|
||||
py::class_<kvc::TransferRequest>(m, "TransferRequest")
|
||||
.def(py::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
|
||||
std::optional<kvc::SyncMessage>>(),
|
||||
py::arg("op"), py::arg("src_descs"), py::arg("dst_descs"), py::arg("remote_name"),
|
||||
py::arg("sync_message") = std::nullopt)
|
||||
.def_property_readonly("op", &kvc::TransferRequest::getOp)
|
||||
.def_property_readonly("src_descs", &kvc::TransferRequest::getSrcDescs)
|
||||
.def_property_readonly("dst_descs", &kvc::TransferRequest::getDstDescs)
|
||||
.def_property_readonly("remote_name", &kvc::TransferRequest::getRemoteName)
|
||||
.def_property_readonly("sync_message", &kvc::TransferRequest::getSyncMessage);
|
||||
|
||||
// TransferStatus base class
|
||||
py::class_<kvc::TransferStatus>(m, "TransferStatus")
|
||||
.def("is_completed", &kvc::TransferStatus::isCompleted)
|
||||
.def("wait", &kvc::TransferStatus::wait, py::arg("timeout_ms") = -1);
|
||||
|
||||
// BaseAgentConfig struct
|
||||
py::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init(
|
||||
[](std::string name, bool use_prog_thread, bool multi_thread, bool use_listen_thread,
|
||||
unsigned int num_workers) {
|
||||
return kvc::BaseAgentConfig{
|
||||
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
|
||||
}),
|
||||
py::arg("name"), py::arg("use_prog_thread") = true, py::arg("multi_thread") = false,
|
||||
py::arg("use_listen_thread") = false, py::arg("num_workers") = 1)
|
||||
.def_readwrite("name", &kvc::BaseAgentConfig::mName)
|
||||
.def_readwrite("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
|
||||
.def_readwrite("multi_thread", &kvc::BaseAgentConfig::multiThread)
|
||||
.def_readwrite("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
|
||||
.def_readwrite("num_workers", &kvc::BaseAgentConfig::numWorkers);
|
||||
|
||||
// BaseTransferAgent class (abstract base)
|
||||
py::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
|
||||
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, py::arg("descs"))
|
||||
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, py::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::BaseTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
|
||||
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, py::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
py::arg("request"), py::return_value_policy::take_ownership)
|
||||
.def(
|
||||
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
|
||||
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
|
||||
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
|
||||
|
||||
#ifdef ENABLE_NIXL
|
||||
// NixlTransferStatus class - release GIL for blocking operations
|
||||
py::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
|
||||
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
|
||||
.def("wait", &kvc::NixlTransferStatus::wait, py::arg("timeout_ms") = -1,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
// NixlTransferAgent class
|
||||
py::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
|
||||
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
|
||||
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, py::arg("descs"))
|
||||
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, py::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::NixlTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
|
||||
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
|
||||
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, py::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
|
||||
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
// MooncakeTransferStatus class - release GIL for blocking operations
|
||||
py::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
|
||||
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
|
||||
.def("wait", &kvc::MooncakeTransferStatus::wait, py::arg("timeout_ms") = -1,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
// MooncakeTransferAgent class
|
||||
py::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
|
||||
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
|
||||
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, py::arg("descs"))
|
||||
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, py::arg("descs"))
|
||||
.def("load_remote_agent",
|
||||
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("agent_desc"))
|
||||
.def("load_remote_agent_by_connection",
|
||||
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
|
||||
&kvc::MooncakeTransferAgent::loadRemoteAgent),
|
||||
py::arg("name"), py::arg("connection_info"))
|
||||
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
|
||||
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
|
||||
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, py::arg("name"))
|
||||
.def(
|
||||
"submit_transfer_requests",
|
||||
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
|
||||
{ return self.submitTransferRequests(request).release(); },
|
||||
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
|
||||
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, py::arg("name"),
|
||||
py::arg("sync_message"))
|
||||
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
|
||||
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, py::arg("name"),
|
||||
py::arg("memory_descs"));
|
||||
#endif
|
||||
|
||||
// Factory function to create transfer agent by backend name (uses dynamic loading)
|
||||
m.def(
|
||||
"make_transfer_agent",
|
||||
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
|
||||
{ return kvc::makeTransferAgent(backend, &config).release(); },
|
||||
py::arg("backend"), py::arg("config"), py::return_value_policy::take_ownership,
|
||||
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
|
||||
|
||||
// Expose which backends are available
|
||||
#ifdef ENABLE_NIXL
|
||||
m.attr("NIXL_ENABLED") = true;
|
||||
#else
|
||||
m.attr("NIXL_ENABLED") = false;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MOONCAKE
|
||||
m.attr("MOONCAKE_ENABLED") = true;
|
||||
#else
|
||||
m.attr("MOONCAKE_ENABLED") = false;
|
||||
#endif
|
||||
}
|
||||
@ -22,6 +22,7 @@
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <chrono>
|
||||
#include <dirent.h>
|
||||
#include <fcntl.h>
|
||||
#include <ifaddrs.h>
|
||||
@ -31,6 +32,7 @@
|
||||
#include <set>
|
||||
#include <sys/file.h>
|
||||
#include <sys/stat.h>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
|
||||
@ -318,10 +320,40 @@ NixlTransferStatus::NixlTransferStatus(nixlAgent* agent, nixlXferReqH* handle)
|
||||
TLLM_CHECK(mHandle);
|
||||
}
|
||||
|
||||
void NixlTransferStatus::wait() const
|
||||
TransferState NixlTransferStatus::wait(int64_t timeout_ms) const
|
||||
{
|
||||
while (!isCompleted())
|
||||
;
|
||||
auto startTime = std::chrono::steady_clock::now();
|
||||
|
||||
while (true)
|
||||
{
|
||||
auto status = mRawAgent->getXferStatus(mHandle);
|
||||
if (status == NIXL_SUCCESS)
|
||||
{
|
||||
return TransferState::kSUCCESS;
|
||||
}
|
||||
else if (status != NIXL_IN_PROG)
|
||||
{
|
||||
return TransferState::kFAILURE;
|
||||
}
|
||||
|
||||
// If timeout_ms < 0, wait indefinitely until status is not NIXL_IN_PROG
|
||||
if (timeout_ms < 0)
|
||||
{
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if timeout has elapsed
|
||||
auto elapsed
|
||||
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
|
||||
.count();
|
||||
if (elapsed >= timeout_ms)
|
||||
{
|
||||
return TransferState::kIN_PROGRESS;
|
||||
}
|
||||
|
||||
std::this_thread::yield();
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] bool NixlTransferStatus::isCompleted() const
|
||||
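With the timeout added above, wait() reports IN_PROGRESS once the budget elapses instead of blocking forever; a small sketch of the intended polling pattern from Python (the 100 ms budget and do_other_work are illustrative):

# Poll a submitted transfer with a bounded wait per iteration.
import tensorrt_llm_transfer_agent_binding as tab

state = status.wait(timeout_ms=100)      # status from submit_transfer_requests()
while state == tab.TransferState.IN_PROGRESS:
    do_other_work()                      # hypothetical work between polls
    state = status.wait(timeout_ms=100)
assert state == tab.TransferState.SUCCESS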
@ -333,6 +365,7 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
|
||||
: mName{config.mName}
|
||||
{
|
||||
nixl_status_t status;
|
||||
if (config.useListenThread)
|
||||
{
|
||||
FileLock lock("/tmp/trtllm_nixl_port.lock");
|
||||
if (!lock.lock())
|
||||
@ -341,10 +374,18 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
|
||||
}
|
||||
auto envPort = common::getEnvNixlPort();
|
||||
uint16_t port = envPort > 0 ? getIncrmentPort(envPort) : getAvailablePort();
|
||||
nixlAgentConfig nixlConfig{config.useProgThread, true, port};
|
||||
nixlAgentConfig nixlConfig{
|
||||
config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
|
||||
mAddress = getAvailableIP() + ":" + std::to_string(port);
|
||||
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
|
||||
}
|
||||
else
|
||||
{
|
||||
mAddress.clear();
|
||||
nixlAgentConfig nixlConfig{
|
||||
config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
|
||||
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
|
||||
}
|
||||
|
||||
std::string nixlBackend = common::getEnvNixlBackend();
|
||||
// List of supported backends - extend this list as new backends are added
|
||||
@ -645,7 +686,8 @@ void NixlLoopbackAgent::executeLoopbackRequest(
|
||||
|
||||
std::unique_ptr<TransferStatus> status = this->submitLoopbackRequests(memoryDescs, fileDescs, isOffload);
|
||||
TLLM_CHECK_WITH_INFO(status != nullptr, "submitLoopbackRequests failed");
|
||||
status->wait();
|
||||
TransferState transferState = status->wait();
|
||||
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "submitLoopbackRequests failed");
|
||||
|
||||
this->deregisterMemory(memoryDescs);
|
||||
this->deregisterFiles(fileDescs);
|
||||
|
||||
@ -45,7 +45,7 @@ public:
|
||||
|
||||
[[nodiscard]] bool isCompleted() const override;
|
||||
|
||||
void wait() const override;
|
||||
[[nodiscard]] TransferState wait(int64_t timeout_ms = -1) const override;
|
||||
|
||||
private:
|
||||
nixlAgent* mRawAgent{};
|
||||
|
||||
@ -2179,11 +2179,11 @@ void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissio
|
||||
auto req = item.request;
|
||||
if (req->isDisaggContextCompleteState())
|
||||
{
|
||||
// If lastBlockId was tracked, unpin it. Otherwise, just terminate.
|
||||
// If pinnedBlockIds were tracked, unpin them. Otherwise, just terminate.
|
||||
auto kvMgr = mModel->getKVCacheManager();
|
||||
if (kvMgr && item.lastBlockId.has_value())
|
||||
if (kvMgr && !item.pinnedBlockIds.empty())
|
||||
{
|
||||
kvMgr->unpinBlocksById(item.lastBlockId.value());
|
||||
kvMgr->unpinBlocksById(item.pinnedBlockIds);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -2234,14 +2234,14 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses(
|
||||
// move the in transmission requests to another tracker
|
||||
if (llmReq->isDisaggContextTransmissionState())
|
||||
{
|
||||
std::optional<SizeType32> lastBlockId{};
|
||||
std::vector<SizeType32> pinnedBlockIds{};
|
||||
auto kvMgr = mModel->getKVCacheManager();
|
||||
if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
|
||||
{
|
||||
lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
|
||||
pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
|
||||
mModel->terminateRequest(llmReq);
|
||||
}
|
||||
inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId});
|
||||
inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
|
||||
}
|
||||
finishedRequests.push_back(*it);
|
||||
it = activeRequests.erase(it);
|
||||
|
||||
@ -80,12 +80,12 @@ class Executor::Impl
|
||||
using RequestList = std::list<LlmRequestPtr>;
|
||||
|
||||
// When block reuse is enabled for context worker for disaggregated serving,
|
||||
// we need to store the last block id so that we can unpin the block when
|
||||
// we need to store the pinned block ids so that we can unpin them when
|
||||
// the request is finished.
|
||||
struct InTransmissionItem
|
||||
{
|
||||
LlmRequestPtr request;
|
||||
std::optional<SizeType32> lastBlockId;
|
||||
std::vector<SizeType32> pinnedBlockIds;
|
||||
};
|
||||
|
||||
using InTransList = std::list<InTransmissionItem>;
|
||||
|
||||
@ -70,9 +70,9 @@ struct LamportComm
|
||||
{
|
||||
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
|
||||
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
|
||||
clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
|
||||
clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
|
||||
flag_value = *flag_ptr;
|
||||
int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
|
||||
auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
|
||||
clear_size = *clear_ptr;
|
||||
int data_offset = flag_value % 3;
|
||||
int clear_offset = (flag_value + 2) % 3;
|
||||
@ -88,7 +88,7 @@ struct LamportComm
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void update(int new_clear_size)
|
||||
__device__ __forceinline__ void update(int64_t new_clear_size)
|
||||
{
|
||||
if (blockIdx.x == 0 && threadIdx.x == 0)
|
||||
{
|
||||
@ -103,10 +103,10 @@ struct LamportComm
|
||||
|
||||
int* counter_ptr;
|
||||
int* flag_ptr;
|
||||
int* clear_ptr;
|
||||
int64_t* clear_ptr;
|
||||
uint8_t* data_bufs[NRanks];
|
||||
uint8_t* clear_buf;
|
||||
int clear_size;
|
||||
int64_t clear_size;
|
||||
int flag_value;
|
||||
};
|
||||
|
||||
|
||||
@ -21,18 +21,18 @@ TRTLLM_NAMESPACE_BEGIN
|
||||
namespace kernels::ar_fusion
|
||||
{
|
||||
|
||||
__global__ void lamport_initialize_kernel(float* ptr, int size)
|
||||
__global__ void lamport_initialize_kernel(float* ptr, size_t size)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
if (idx >= size)
|
||||
return;
|
||||
ptr[idx] = -0.f;
|
||||
}
|
||||
|
||||
void lamport_initialize(void* ptr, int bytes, cudaStream_t stream)
|
||||
void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream)
|
||||
{
|
||||
int grid_size = (bytes + 127) / 128;
|
||||
lamport_initialize_kernel<<<grid_size, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
|
||||
int grid_size = static_cast<int>((bytes + 1023) / 1024);
|
||||
lamport_initialize_kernel<<<grid_size, 1024, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
|
||||
}
|
||||
|
||||
Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
|
||||
@ -45,10 +45,11 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
|
||||
int device_id;
|
||||
TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
|
||||
m_buffer_mgr = std::make_shared<tensorrt_llm::runtime::BufferManager>(m_cuda_stream);
|
||||
int buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
|
||||
int flag_size = tp_size * kBarrierFlagCount * sizeof(int);
|
||||
int lamport_comm_size = tp_size * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
|
||||
int lamport_buffer_size = 3 * lamport_comm_size;
|
||||
size_t buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
|
||||
size_t flag_size = tp_size * kBarrierFlagCount * sizeof(int);
|
||||
size_t lamport_comm_size
|
||||
= static_cast<size_t>(tp_size) * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
|
||||
size_t lamport_buffer_size = 3 * lamport_comm_size;
|
||||
for (auto size : {buffer_size, flag_size, lamport_buffer_size})
|
||||
{
|
||||
m_ipc_mem_handles.emplace_back(size, *m_buffer_mgr, m_world_config, p2p_supported);
|
||||
@ -61,20 +62,20 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
|
||||
workspace.push_back(ipc_mem_handle.getCommPtrs()[r]);
|
||||
}
|
||||
}
|
||||
// atomic flag read counter
|
||||
// kernel_flag_ptr[0] = 0;
|
||||
// non-lamport flag
|
||||
// kernel_flag_ptr[1] = 0;
|
||||
// lamport flag
|
||||
// kernel_flag_ptr[2] = 0;
|
||||
// lamport triple buffer offset
|
||||
// kernel_flag_ptr[3] = lamport_comm_size;
|
||||
// lamport clear size
|
||||
// kernel_flag_ptr[4] = 0;
|
||||
TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 5 * sizeof(int)));
|
||||
std::vector<int> h_data{0, 0, 0, lamport_comm_size, 0};
|
||||
TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_data.data(), 5 * sizeof(int), cudaMemcpyHostToDevice));
|
||||
// flag_buffer[0], atomic flag read counter
|
||||
// flag_buffer[1], non-lamport flag
|
||||
// flag_buffer[2], lamport flag
|
||||
TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 3 * sizeof(int)));
|
||||
std::vector<int> h_flag_data{0, 0, 0};
|
||||
TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_flag_data.data(), 3 * sizeof(int), cudaMemcpyHostToDevice));
|
||||
workspace.push_back(m_flag_d_ptr);
|
||||
// layout_buffer[0], clear size for next lamport kernel
|
||||
// layout_buffer[1], triple buffer offset for lamport kernel
|
||||
TLLM_CUDA_CHECK(cudaMalloc(&m_layout_d_ptr, 2 * sizeof(int64_t)));
|
||||
std::vector<int64_t> h_layout_data{0, static_cast<int64_t>(lamport_comm_size)};
|
||||
TLLM_CUDA_CHECK(cudaMemcpy(m_layout_d_ptr, h_layout_data.data(), 2 * sizeof(int64_t), cudaMemcpyHostToDevice));
|
||||
workspace.push_back(m_layout_d_ptr);
|
||||
|
||||
TLLM_CUDA_CHECK(cudaMalloc(&m_workspace, workspace.size() * sizeof(void*)));
|
||||
TLLM_CUDA_CHECK(
|
||||
cudaMemcpy(m_workspace, workspace.data(), workspace.size() * sizeof(void*), cudaMemcpyHostToDevice));
|
||||
@ -87,6 +88,10 @@ Workspace::~Workspace()
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaFree(m_flag_d_ptr));
|
||||
}
|
||||
if (m_layout_d_ptr)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaFree(m_layout_d_ptr));
|
||||
}
|
||||
if (m_workspace)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaFree(m_workspace));
|
||||
|
||||
@ -41,9 +41,10 @@ private:
|
||||
void* m_workspace;
|
||||
std::shared_ptr<tensorrt_llm::runtime::CudaStream> m_cuda_stream;
|
||||
void* m_flag_d_ptr;
|
||||
void* m_layout_d_ptr;
|
||||
};
|
||||
|
||||
void lamport_initialize(void* ptr, int bytes, cudaStream_t stream);
|
||||
void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream);
|
||||
} // namespace kernels::ar_fusion
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
|
||||
@ -230,59 +230,62 @@ inline __device__ __host__ T divUp(T m, T n)
|
||||
// Return (block_size, cluster_size, loads_per_thread)
|
||||
std::tuple<int, int, int> adjustGridConfig(int numTokens, int dim, int eltsPerThread)
|
||||
{
|
||||
// Start with preferred block_size and cluster_size
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
int clusterSize = 8;
|
||||
#else
|
||||
int clusterSize = 1;
|
||||
#endif
|
||||
static int SM = tensorrt_llm::common::getSMVersion();
|
||||
|
||||
int clusterSize = SM >= 90 ? 8 : 1;
|
||||
int blockSize = 128;
|
||||
// ========================== Adjust the grid configuration ==========================
|
||||
int threadsNeeded = divUp(dim, eltsPerThread);
|
||||
int loadsPerThread = 1;
|
||||
|
||||
blockSize = divUp(threadsNeeded, clusterSize);
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
|
||||
if (clusterSize > 1)
|
||||
{
|
||||
clusterSize /= 2;
|
||||
while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
|
||||
{
|
||||
clusterSize /= 2;
|
||||
}
|
||||
blockSize = divUp(threadsNeeded, clusterSize);
|
||||
while (blockSize < 128 && clusterSize >= 2)
|
||||
{
|
||||
blockSize *= 2;
|
||||
clusterSize /= 2;
|
||||
}
|
||||
int smCount = getMultiProcessorCount();
|
||||
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
|
||||
{
|
||||
blockSize *= 2;
|
||||
clusterSize /= 2;
|
||||
}
|
||||
}
|
||||
blockSize = divUp(threadsNeeded, clusterSize);
|
||||
while (blockSize < 128 && clusterSize >= 2)
|
||||
{
|
||||
blockSize *= 2;
|
||||
clusterSize /= 2;
|
||||
}
|
||||
int smCount = getMultiProcessorCount();
|
||||
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
|
||||
{
|
||||
blockSize *= 2;
|
||||
clusterSize /= 2;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Trying to scale up use multiple loads or CGA
|
||||
while (blockSize > 1024)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
if (clusterSize < 8)
|
||||
// Scale up with CGA if supported
|
||||
if (SM >= 90)
|
||||
{
|
||||
clusterSize = clusterSize << 1;
|
||||
if (clusterSize < 8)
|
||||
{
|
||||
clusterSize = clusterSize << 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
|
||||
if (loadsPerThread < 8)
|
||||
{
|
||||
loadsPerThread += 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (loadsPerThread < 8)
|
||||
{
|
||||
loadsPerThread += 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
blockSize = divUp(threadsNeeded, clusterSize * loadsPerThread);
|
||||
}
|
||||
return {blockSize, clusterSize, loadsPerThread};
|
||||
@@ -420,9 +423,9 @@ __global__ void __launch_bounds__(1024) oneshotAllreduceFusionKernel(T* outputPt
}
float blockSum = blockReduceSum<float, true>(threadSum);

__shared__ float sharedVal[8]; // Temporary variable to share the sum within block
float fullSum = blockSum;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__shared__ float sharedVal[8]; // Temporary variable to share the sum within block
namespace cg = cooperative_groups;
cg::cluster_group cluster = cg::this_cluster();
int const numBlocks = cluster.num_blocks();

@@ -459,6 +462,8 @@ using detail::adjustGridConfig;

void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
{

static int const kSMVersion = tensorrt_llm::common::getSMVersion();
int const numTokens = params.numTokens;
int const tokenDim = params.tokenDim;
int const eltsPerThread = sizeof(float4) / getDTypeSize(params.dType);

@@ -466,38 +471,31 @@ void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
auto [blockSize, clusterSize, loadsPerThread] = adjustGridConfig(numTokens, tokenDim, eltsPerThread);
dim3 grid(numTokens, clusterSize, 1);

TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
"Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
1024 * 8 * eltsPerThread);
#else
1024 * eltsPerThread);
#endif

TLLM_LOG_DEBUG(
"[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: %d, "
"loads_per_thread: %d, "
"threads_needed: %d",
numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, divUp(tokenDim, eltsPerThread));

TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
"Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
1024 * (kSMVersion >= 90 ? 8 : 1) * eltsPerThread);

cudaLaunchAttribute attrs[2];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
attrs[1].id = cudaLaunchAttributeClusterDimension;
attrs[1].val.clusterDim.x = 1;
attrs[1].val.clusterDim.y = clusterSize;
attrs[1].val.clusterDim.z = 1;
#endif

cudaLaunchConfig_t config
{
.gridDim = grid, .blockDim = blockSize, .dynamicSmemBytes = 0, .stream = params.stream, .attrs = attrs,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
.numAttrs = 2,
#else
.numAttrs = 1,
#endif
cudaLaunchConfig_t config{
.gridDim = grid,
.blockDim = blockSize,
.dynamicSmemBytes = 0,
.stream = params.stream,
.attrs = attrs,
.numAttrs = kSMVersion >= 90 ? 2U : 1U,
};

#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T, RMSNORM) \
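The designated-initializer `cudaLaunchConfig_t` above is the form consumed by `cudaLaunchKernelEx`; keeping the cluster-dimension attribute in the array but counting it in `numAttrs` only on SM90+ lets one code path serve both pre-Hopper and Hopper parts. A minimal, hypothetical launch sketch; the kernel, its arguments, and the helper name are placeholders, not the real fused kernel:

#include <cuda_runtime.h>

__global__ void placeholderKernel(float* out, int n)
{
    // Placeholder body; the real code launches the fused allreduce kernel instead.
    if (threadIdx.x == 0 && blockIdx.x < n) out[blockIdx.x] = 0.f;
}

void launchWithOptionalCluster(float* out, int numTokens, int blockSize, int clusterSize, int smVersion, cudaStream_t stream)
{
    cudaLaunchAttribute attrs[2];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 0; // PDL disabled in this sketch
    attrs[1].id = cudaLaunchAttributeClusterDimension;
    attrs[1].val.clusterDim.x = 1;
    attrs[1].val.clusterDim.y = clusterSize;
    attrs[1].val.clusterDim.z = 1;

    cudaLaunchConfig_t config{};
    config.gridDim = dim3(numTokens, clusterSize, 1);
    config.blockDim = dim3(blockSize, 1, 1);
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    config.attrs = attrs;
    config.numAttrs = (smVersion >= 90) ? 2u : 1u; // cluster attribute only where CGA exists

    cudaLaunchKernelEx(&config, &placeholderKernel, out, numTokens);
}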
@@ -831,9 +829,9 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
float blockSum = blockReduceSum<float, true>(threadSum);

float fullSum = blockSum;
__shared__ float sharedVal[8];
// Use CGA Reduction if supported
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__shared__ float sharedVal[8];
int const numBlocks = cluster.num_blocks();
if (numBlocks > 1)
{

@@ -876,6 +874,11 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
}
constexpr int kELTS_SIZE = sizeof(T_IN);

// Issue ACQBLK at the end. Assuming preceding kernel will not modify the buffer_flags.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif

// Update the buffer pointers
flag.waitAndUpdate({static_cast<uint32_t>(divUp<uint32_t>(numTokens, worldSize) * worldSize * dim * kELTS_SIZE),
static_cast<uint32_t>(numTokens * dim * kELTS_SIZE), 0, 0});

@@ -883,6 +886,7 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU

void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
{
static int const kSMVersion = tensorrt_llm::common::getSMVersion();
int const numTokens = params.numTokens;
int const tokenDim = params.tokenDim;
int const numEltsPerThread = sizeof(float4) / getDTypeSize(params.dType);

@@ -959,17 +963,13 @@ void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
rnConfig.attrs = rnAttrs;
rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
rnAttrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
#ifndef DISABLE_CGA
rnAttrs[1].id = cudaLaunchAttributeClusterDimension;
rnAttrs[1].val.clusterDim.x = 1;
rnAttrs[1].val.clusterDim.y = rnClusterSize;
rnAttrs[1].val.clusterDim.z = 1;
rnConfig.numAttrs = 2;
#else
rnConfig.numAttrs = 1;
#endif
rnConfig.numAttrs = (kSMVersion >= 90) ? 2U : 1U;

bool const rnUseCGA = rnClusterSize > 1;
bool const rnUseCGA = kSMVersion >= 90 && rnClusterSize > 1;
int const dimPadded = divUp(tokenDim, numEltsPerThread * rnNumThreads) * numEltsPerThread * rnNumThreads;
int const iters = dimPadded / rnNumThreads;


@@ -31,9 +31,9 @@ struct LamportComm
{
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
flag_value = *flag_ptr;
int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
clear_size = *clear_ptr;
int data_offset = flag_value % 3;
int clear_offset = (flag_value + 2) % 3;

@@ -49,7 +49,7 @@ struct LamportComm
}
}

__device__ __forceinline__ void update(int new_clear_size)
__device__ __forceinline__ void update(int64_t new_clear_size)
{
if (blockIdx.x == 0 && threadIdx.x == 0)
{

@@ -64,10 +64,10 @@ struct LamportComm

int* counter_ptr;
int* flag_ptr;
int* clear_ptr;
int64_t* clear_ptr;
uint8_t* data_bufs[NRanks];
uint8_t* clear_buf;
int clear_size;
int64_t clear_size;
int flag_value;
};

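For context on the widening above: `clear_size` appears to track how much of the Lamport buffer has to be scrubbed before reuse (my reading of the hunk, not stated in the diff), and with large batch and hidden sizes such a byte count can pass the 2^31-1 limit of a 32-bit `int`. A purely hypothetical back-of-the-envelope check; none of these numbers come from the diff:

#include <cstdint>
#include <iostream>
#include <limits>

int main()
{
    // 32768 tokens x 16384 hidden x 2 bytes x 3 buffers = 3,221,225,472 bytes
    std::int64_t const tokens = 32768, hidden = 16384, bytesPerElt = 2, buffers = 3;
    std::int64_t const clearBytes = tokens * hidden * bytesPerElt * buffers;
    std::cout << "overflows int32: "
              << (clearBytes > std::numeric_limits<std::int32_t>::max()) << "\n"; // prints 1
    return 0;
}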
@@ -48,6 +48,12 @@ namespace kernels::moe_comm
#define SWITCH_TOP_K(top_k, TOP_K, ...) \
switch (top_k) \
{ \
case 22: \
{ \
constexpr int TOP_K = 22; \
__VA_ARGS__; \
break; \
} \
case 16: \
{ \
constexpr int TOP_K = 16; \
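The new `case 22` extends the usual runtime-to-compile-time dispatch: the macro turns a runtime `top_k` into a `constexpr int TOP_K`, so per-token routing arrays can be sized statically and loops over `TOP_K` can be fully unrolled. A stripped-down sketch of the same idea with a hypothetical kernel (not the full macro from this file):

#include <cuda_runtime.h>

// 'toyKernel' stands in for moeA2ADispatchKernel and friends.
template <int TOP_K>
__global__ void toyKernel(int const* experts, int* out)
{
    int vals[TOP_K]; // statically sized because TOP_K is a compile-time constant
#pragma unroll
    for (int k = 0; k < TOP_K; ++k)
    {
        vals[k] = experts[blockIdx.x * TOP_K + k];
        out[blockIdx.x * TOP_K + k] = vals[k];
    }
}

void launchToy(int top_k, int const* experts, int* out, int numTokens, cudaStream_t stream)
{
    switch (top_k)
    {
    case 16: toyKernel<16><<<numTokens, 32, 0, stream>>>(experts, out); break;
    case 22: toyKernel<22><<<numTokens, 32, 0, stream>>>(experts, out); break;
    default: /* unsupported top_k in this sketch */ break;
    }
}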
@@ -362,88 +368,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int thread_idx = ThreadingPolicy::offset();
int local_token_idx = ThreadingPolicy::token_idx();

if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
}

// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;

uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);

if (already_copied & (1ULL << target_rank))
// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}

uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);

if (already_copied & (1ULL << target_rank))
{
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
}
continue;
}

// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);

ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
continue;
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();

// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);

ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();

// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
#pragma unroll
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
}

// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;

vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank,
payload_idx, ptrs, topk_target_ranks, topk_send_indices);
}

ThreadingPolicy::sync();
}

// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;

vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx,
ptrs, topk_target_ranks, topk_send_indices);
}

ThreadingPolicy::sync();

bool is_first_warp = threadIdx.x / warpSize == 0;
if (is_first_warp)
{

@@ -452,8 +468,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
bool is_last_token = false;
if (lane_id == 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
if (local_num_tokens != 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
}
else
{
is_last_token = true;
}
}
is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);

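The hunk above also folds the empty-rank case into the usual "last one out signals completion" idiom: one lane per processed token bumps a global counter, the lane that observes the final count flags itself as last, and the flag is broadcast across the warp with `__shfl_sync`. A standalone sketch of that idiom; the counter, flag, and kernel names are illustrative, not the ones in this file:

#include <cuda_runtime.h>

// Minimal sketch of the last-finisher pattern: the block whose atomicAdd
// returns total-1 knows every other block has already finished.
__global__ void lastFinisherSketch(int* workCounter, int totalItems, int* doneFlag)
{
    // ... per-block work for item blockIdx.x would go here ...

    if (threadIdx.x == 0)
    {
        bool isLast = (totalItems == 0); // an empty problem still needs one block to signal
        if (totalItems != 0)
        {
            int cnt = atomicAdd(workCounter, 1);
            isLast = (cnt + 1 == totalItems);
        }
        if (isLast)
        {
            *doneFlag = 1; // the real code would signal the remote ranks here
        }
    }
}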
@@ -523,7 +546,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);

// Prepare kernel pointers struct

@@ -568,6 +591,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
if (params.one_block_per_token)
{
int grid_size = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(

@@ -577,6 +605,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
else
{
int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
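Relaxing the check to `local_num_tokens >= 0` and clamping the grid to at least one block go together: a rank that received no tokens must still launch something, because every rank has to take part in the dispatch/combine synchronization. A compact host-side sketch of that clamping, with illustrative names only:

#include <algorithm>

// Hypothetical host-side sketch: always launch at least one block so an
// empty rank can still participate in the cross-rank synchronization.
inline int dispatchGridSize(int localNumTokens, bool oneBlockPerToken, int warpsPerBlock)
{
    int grid = oneBlockPerToken ? localNumTokens
                                : (localNumTokens + warpsPerBlock - 1) / warpsPerBlock;
    return std::max(grid, 1); // a grid of 0 would skip the kernel (and the sync) entirely
}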
@@ -626,7 +659,70 @@ __device__ void vectorized_combine_impl(
// Load directly into the per-k accumulator; reduce across k below
acc[k].load(recv_buffer + base_token + offset);
}
if constexpr (TOP_K == 16)
// Reduce acc[TOP_K] into acc[0]
if constexpr (TOP_K == 22)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
T* a2 = reinterpret_cast<T*>(&acc[2]);
T* a3 = reinterpret_cast<T*>(&acc[3]);
T* a4 = reinterpret_cast<T*>(&acc[4]);
T* a5 = reinterpret_cast<T*>(&acc[5]);
T* a6 = reinterpret_cast<T*>(&acc[6]);
T* a7 = reinterpret_cast<T*>(&acc[7]);
T* a8 = reinterpret_cast<T*>(&acc[8]);
T* a9 = reinterpret_cast<T*>(&acc[9]);
T* a10 = reinterpret_cast<T*>(&acc[10]);
T* a11 = reinterpret_cast<T*>(&acc[11]);
T* a12 = reinterpret_cast<T*>(&acc[12]);
T* a13 = reinterpret_cast<T*>(&acc[13]);
T* a14 = reinterpret_cast<T*>(&acc[14]);
T* a15 = reinterpret_cast<T*>(&acc[15]);
T* a16 = reinterpret_cast<T*>(&acc[16]);
T* a17 = reinterpret_cast<T*>(&acc[17]);
T* a18 = reinterpret_cast<T*>(&acc[18]);
T* a19 = reinterpret_cast<T*>(&acc[19]);
T* a20 = reinterpret_cast<T*>(&acc[20]);
T* a21 = reinterpret_cast<T*>(&acc[21]);
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a1[j];
a2[j] += a3[j];
a4[j] += a5[j];
a6[j] += a7[j];
a8[j] += a9[j];
a10[j] += a11[j];
a12[j] += a13[j];
a14[j] += a15[j];
a16[j] += a17[j];
a18[j] += a19[j];
a20[j] += a21[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a2[j];
a4[j] += a6[j];
a8[j] += a10[j];
a12[j] += a14[j];
a16[j] += a18[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a4[j];
a8[j] += a12[j];
a16[j] += a20[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a8[j];
a0[j] += a16[j];
}
}
else if constexpr (TOP_K == 16)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);

@@ -710,9 +806,7 @@ __device__ void vectorized_combine_impl(
a0[j] += a8[j];
}
}

// Reduce acc[TOP_K] into acc[0]
if constexpr (TOP_K == 8)
else if constexpr (TOP_K == 8)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
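The TOP_K == 22 branch follows the same pairwise tree as the existing 16- and 8-way branches, just with an odd tail (11 pairs, then 5 pairs plus a carried accumulator, and so on), which keeps the adds in each round independent rather than serializing them in one long chain. A generic sketch of that tree, hypothetical and not the helper used in this file:

// Hypothetical generic form of the unrolled reductions above: fold
// acc[1..TOP_K-1] into acc[0] with a pairwise tree so each round's adds are
// independent. 'VecT' stands in for the vector accumulator type.
template <typename T, int TOP_K, int ELEMS_PER_VEC, typename VecT>
__device__ inline void reduceTopKIntoAcc0(VecT (&acc)[TOP_K])
{
#pragma unroll
    for (int stride = 1; stride < TOP_K; stride *= 2)
    {
#pragma unroll
        for (int i = 0; i + stride < TOP_K; i += 2 * stride)
        {
            T* dst = reinterpret_cast<T*>(&acc[i]);
            T* src = reinterpret_cast<T*>(&acc[i + stride]);
#pragma unroll
            for (int j = 0; j < ELEMS_PER_VEC; ++j)
            {
                dst[j] += src[j];
            }
        }
    }
}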
@@ -897,9 +991,19 @@ __global__ void moeA2ACombineKernel(
int local_token_idx = ThreadingPolicy::token_idx();
int const size_per_token = elements_per_token * sizeof(T);

if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
}

#if !DISABLE_SYNC_FOR_PROFILING

@@ -951,6 +1055,9 @@ __global__ void moeA2ACombineKernel(
__syncthreads();
#endif

if (local_num_tokens == 0)
return;

// Get output location for this token (using src_data_ptrs[0] as output)
T* token_output = static_cast<T*>(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token;

@@ -1003,7 +1110,7 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.elements_per_token > 0);

// Configure kernel launch

@@ -1011,6 +1118,15 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
int const kWarpsPerBlock = kBlockSize / 32; // warpSize
int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
int grid_size_block = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size_warp == 0)
{
grid_size_warp = 1;
}
if (grid_size_block == 0)
{
grid_size_block = 1;
}

// Prepare kernel pointers struct for combine
CombineKernelPointers kernel_ptrs = {}; // Zero-initialize

@@ -26,7 +26,7 @@ namespace kernels::moe_comm
{

// Configuration constants
static constexpr int kMaxTopK = 16; // Maximum top-k experts per token
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
static constexpr int kMaxRanks = 64; // Maximum supported EP size

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6509dd36fb92554c6078595951a8de698d7bdaa07b9b817bfcdd255d4303bca
size 687070
oid sha256:4f1f3679968b8f6dea77f53534af9eb1348b6f476d4c3880833b41dd4cc9c803
size 687860

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b22d606e19b52047ae67319d61f138562f2b81df08ccde3f8fa04f040d408d7a
size 669688
oid sha256:a0d7061b400ab387309af00ae12f7a840b5abb91757183f415ca18329bbdb358
size 670478

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a70e335677a1b0f9d98267fe7701735e42f105720403489276d48a4247ea1b5
size 423835
oid sha256:4a91ff0238b0c8f1d40f8441f22a60a2c64d344b8550de68737292ff449d1d7e
size 426203

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8289200bf78517033295966e9dbdf5c647da9aa7089669ff473ba436fef6a798
size 1230152
oid sha256:4d094c39dbdd372166facb297a4a91be80fb231bf3cca89afa97e61cc725f67e
size 1228572

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97cc5f8d42d92332a92fa216847bbacccc7ef9f9d5208bd26585cd702d03fe57
size 1725040
oid sha256:1fe830d32459fd9a25d54e1d00a98720afd938d9e9042e2b5903f969e991d72d
size 1721882

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1264927817c08da144e387a7258f6c6fe424c0ff159f3ab0d6ffa3c4e3947598
size 375671
oid sha256:09af1ef9197c628c4a31cc58276ee6dcfad03f751069a78b5242594f93ea8c97
size 378039

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:950fb45e94ffc8e2ec9f5a4b682075be55cb85d6415b3eeb172ce2cf7d53220d
size 1140954
oid sha256:9e93bb514c30bc5a4cda8f402a386ab85d079f9b97aeff04788cf3c8a8cc87a6
size 1137008

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba97e1bf342788eaf74a78f542f870d3967214aed98b98600fae772aad5bad5f
size 653960
oid sha256:0dc47824dfc41004c5b243ce9f40eefeee15c69b88474e33ec13137ef56604e8
size 651592

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:337cc83d1880b1496e2f054285472b693c181e081819f425ddf2ea45a5dfe9f4
size 1130682
oid sha256:c0f042eabb29ee9db7ddf9791840337a7544653b295e4b2a5068b7f80bcd8251
size 1128314

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:859ffffa18f1c9c8068a1cfedec487c2e0eab84af2c3720eaa7bb2a044ea16f6
size 1534006
oid sha256:7a9d887dd0acea6d82a25e0dda908f4c5421eaa1ddbfeeb49d382c079156d67e
size 1535586

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02bc55faacb50d0501c590ed11b40d802b374618cbde58db725cc67495762064
size 698136
oid sha256:22a7eaab8e44194acd83621e5546f164ad9cbeda8b67867f864a235036a03931
size 690242

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:510d6c9942dea4bef976c2307fc63f1d7341d78ad8b41cca3bf80bae0a377575
size 380847
oid sha256:e22fe2dde7f5542975db7517b37cdce0eaa656fed2bc58378b37a872c54a43ef
size 374533

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0e0d34e15f533f756ac4ad6ef8889e5ed7556d859b6263509f608f2e7194e0a
size 964134

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd7941b92a10c3116b3d93b50ce94d90627ed020e1aa4263b2c46926db60250
size 1008328

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04439f4bdd5bf15dce0d59e455545236ed5b98c963a9b491c40d473eb766a04f
size 988580
oid sha256:ec624d7dceea5234b9dd4e43125f271e46ed4f2a4118837a23e00eb89571dcb2
size 985422