Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-18 08:45:05 +08:00.
Compare commits: v1.2.0rc6...main (926 commits)
284 .cmake-format.json Normal file
@@ -0,0 +1,284 @@
{
  "_help_parse": "Options affecting listfile parsing",
  "parse": {
    "_help_additional_commands": [
      "Specify structure for custom cmake functions"
    ],
    "additional_commands": {
      "filter_source_cuda_architectures": {
        "flags": [
          "IMPLICIT_FAMILY"
        ],
        "kwargs": {
          "SOURCE_LIST": "1",
          "TARGET": "1",
          "ARCHS": "+"
        }
      }
    },
    "_help_vartags": [
      "Specify variable tags."
    ],
    "vartags": [],
    "_help_proptags": [
      "Specify property tags."
    ],
    "proptags": []
  },
  "_help_format": "Options affecting formatting.",
  "format": {
    "_help_line_width": [
      "How wide to allow formatted cmake files"
    ],
    "line_width": 80,
    "_help_tab_size": [
      "How many spaces to tab for indent"
    ],
    "tab_size": 2,
    "_help_max_subgroups_hwrap": [
      "If an argument group contains more than this many sub-groups",
      "(parg or kwarg groups) then force it to a vertical layout."
    ],
    "max_subgroups_hwrap": 2,
    "_help_max_pargs_hwrap": [
      "If a positional argument group contains more than this many",
      "arguments, then force it to a vertical layout."
    ],
    "max_pargs_hwrap": 6,
    "_help_max_rows_cmdline": [
      "If a cmdline positional group consumes more than this many",
      "lines without nesting, then invalidate the layout (and nest)"
    ],
    "max_rows_cmdline": 2,
    "_help_separate_ctrl_name_with_space": [
      "If true, separate flow control names from their parentheses",
      "with a space"
    ],
    "separate_ctrl_name_with_space": false,
    "_help_separate_fn_name_with_space": [
      "If true, separate function names from parentheses with a",
      "space"
    ],
    "separate_fn_name_with_space": false,
    "_help_dangle_parens": [
      "If a statement is wrapped to more than one line, than dangle",
      "the closing parenthesis on its own line."
    ],
    "dangle_parens": false,
    "_help_dangle_align": [
      "If the trailing parenthesis must be 'dangled' on its on",
      "line, then align it to this reference: `prefix`: the start",
      "of the statement, `prefix-indent`: the start of the",
      "statement, plus one indentation level, `child`: align to",
      "the column of the arguments"
    ],
    "dangle_align": "prefix",
    "_help_min_prefix_chars": [
      "If the statement spelling length (including space and",
      "parenthesis) is smaller than this amount, then force reject",
      "nested layouts."
    ],
    "min_prefix_chars": 4,
    "_help_max_prefix_chars": [
      "If the statement spelling length (including space and",
      "parenthesis) is larger than the tab width by more than this",
      "amount, then force reject un-nested layouts."
    ],
    "max_prefix_chars": 10,
    "_help_max_lines_hwrap": [
      "If a candidate layout is wrapped horizontally but it exceeds",
      "this many lines, then reject the layout."
    ],
    "max_lines_hwrap": 2,
    "_help_line_ending": [
      "What style line endings to use in the output."
    ],
    "line_ending": "unix",
    "_help_command_case": [
      "Format command names consistently as 'lower' or 'upper' case"
    ],
    "command_case": "canonical",
    "_help_keyword_case": [
      "Format keywords consistently as 'lower' or 'upper' case"
    ],
    "keyword_case": "unchanged",
    "_help_always_wrap": [
      "A list of command names which should always be wrapped"
    ],
    "always_wrap": [],
    "_help_enable_sort": [
      "If true, the argument lists which are known to be sortable",
      "will be sorted lexicographicall"
    ],
    "enable_sort": true,
    "_help_autosort": [
      "If true, the parsers may infer whether or not an argument",
      "list is sortable (without annotation)."
    ],
    "autosort": false,
    "_help_require_valid_layout": [
      "By default, if cmake-format cannot successfully fit",
      "everything into the desired linewidth it will apply the",
      "last, most agressive attempt that it made. If this flag is",
      "True, however, cmake-format will print error, exit with non-",
      "zero status code, and write-out nothing"
    ],
    "require_valid_layout": false,
    "_help_layout_passes": [
      "A dictionary mapping layout nodes to a list of wrap",
      "decisions. See the documentation for more information."
    ],
    "layout_passes": {}
  },
  "_help_markup": "Options affecting comment reflow and formatting.",
  "markup": {
    "_help_bullet_char": [
      "What character to use for bulleted lists"
    ],
    "bullet_char": "*",
    "_help_enum_char": [
      "What character to use as punctuation after numerals in an",
      "enumerated list"
    ],
    "enum_char": ".",
    "_help_first_comment_is_literal": [
      "If comment markup is enabled, don't reflow the first comment",
      "block in each listfile. Use this to preserve formatting of",
      "your copyright/license statements."
    ],
    "first_comment_is_literal": false,
    "_help_literal_comment_pattern": [
      "If comment markup is enabled, don't reflow any comment block",
      "which matches this (regex) pattern. Default is `None`",
      "(disabled)."
    ],
    "literal_comment_pattern": null,
    "_help_fence_pattern": [
      "Regular expression to match preformat fences in comments",
      "default= ``r'^\\s*([`~]{3}[`~]*)(.*)$'``"
    ],
    "fence_pattern": "^\\s*([`~]{3}[`~]*)(.*)$",
    "_help_ruler_pattern": [
      "Regular expression to match rulers in comments default=",
      "``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``"
    ],
    "ruler_pattern": "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$",
    "_help_explicit_trailing_pattern": [
      "If a comment line matches starts with this pattern then it",
      "is explicitly a trailing comment for the preceeding",
      "argument. Default is '#<'"
    ],
    "explicit_trailing_pattern": "#<",
    "_help_hashruler_min_length": [
      "If a comment line starts with at least this many consecutive",
      "hash characters, then don't lstrip() them off. This allows",
      "for lazy hash rulers where the first hash char is not",
      "separated by space"
    ],
    "hashruler_min_length": 10,
    "_help_canonicalize_hashrulers": [
      "If true, then insert a space between the first hash char and",
      "remaining hash chars in a hash ruler, and normalize its",
      "length to fill the column"
    ],
    "canonicalize_hashrulers": true,
    "_help_enable_markup": [
      "enable comment markup parsing and reflow"
    ],
    "enable_markup": true
  },
  "_help_lint": "Options affecting the linter",
  "lint": {
    "_help_disabled_codes": [
      "a list of lint codes to disable"
    ],
    "disabled_codes": [],
    "_help_function_pattern": [
      "regular expression pattern describing valid function names"
    ],
    "function_pattern": "[0-9a-z_]+",
    "_help_macro_pattern": [
      "regular expression pattern describing valid macro names"
    ],
    "macro_pattern": "[0-9A-Z_]+",
    "_help_global_var_pattern": [
      "regular expression pattern describing valid names for",
      "variables with global (cache) scope"
    ],
    "global_var_pattern": "[A-Z][0-9A-Z_]+",
    "_help_internal_var_pattern": [
      "regular expression pattern describing valid names for",
      "variables with global scope (but internal semantic)"
    ],
    "internal_var_pattern": "_[A-Z][0-9A-Z_]+",
    "_help_local_var_pattern": [
      "regular expression pattern describing valid names for",
      "variables with local scope"
    ],
    "local_var_pattern": "[a-z][a-z0-9_]+",
    "_help_private_var_pattern": [
      "regular expression pattern describing valid names for",
      "privatedirectory variables"
    ],
    "private_var_pattern": "_[0-9a-z_]+",
    "_help_public_var_pattern": [
      "regular expression pattern describing valid names for public",
      "directory variables"
    ],
    "public_var_pattern": "[A-Z][0-9A-Z_]+",
    "_help_argument_var_pattern": [
      "regular expression pattern describing valid names for",
      "function/macro arguments and loop variables."
    ],
    "argument_var_pattern": "[a-z][a-z0-9_]+",
    "_help_keyword_pattern": [
      "regular expression pattern describing valid names for",
      "keywords used in functions or macros"
    ],
    "keyword_pattern": "[A-Z][0-9A-Z_]+",
    "_help_max_conditionals_custom_parser": [
      "In the heuristic for C0201, how many conditionals to match",
      "within a loop in before considering the loop a parser."
    ],
    "max_conditionals_custom_parser": 2,
    "_help_min_statement_spacing": [
      "Require at least this many newlines between statements"
    ],
    "min_statement_spacing": 1,
    "_help_max_statement_spacing": [
      "Require no more than this many newlines between statements"
    ],
    "max_statement_spacing": 2,
    "max_returns": 6,
    "max_branches": 12,
    "max_arguments": 5,
    "max_localvars": 15,
    "max_statements": 50
  },
  "_help_encode": "Options affecting file encoding",
  "encode": {
    "_help_emit_byteorder_mark": [
      "If true, emit the unicode byte-order mark (BOM) at the start",
      "of the file"
    ],
    "emit_byteorder_mark": false,
    "_help_input_encoding": [
      "Specify the encoding of the input file. Defaults to utf-8"
    ],
    "input_encoding": "utf-8",
    "_help_output_encoding": [
      "Specify the encoding of the output file. Defaults to utf-8.",
      "Note that cmake only claims to support utf-8 so be careful",
      "when using anything else"
    ],
    "output_encoding": "utf-8"
  },
  "_help_misc": "Miscellaneous configurations options.",
  "misc": {
    "_help_per_command": [
      "A dictionary containing any per-command configuration",
      "overrides. Currently only `command_case` is supported."
    ],
    "per_command": {}
  }
}
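The "additional_commands" entry above exists so that cmake-format can parse the project's custom filter_source_cuda_architectures function instead of treating its arguments as one opaque list. A minimal sketch of a call shape consistent with that declaration follows; the target name, source-list variable, and architecture values are hypothetical, chosen only for illustration:

filter_source_cuda_architectures(
  TARGET my_cuda_target # "TARGET": "1" -- exactly one value
  SOURCE_LIST MY_CUDA_SOURCES # "SOURCE_LIST": "1" -- exactly one value
  ARCHS 80 90 # "ARCHS": "+" -- one or more values
  IMPLICIT_FAMILY) # listed under "flags" -- a bare option with no arguments

With this hint, the formatter knows which tokens are keywords and can wrap the call vertically per the "format" options (line_width 80, tab_size 2) instead of guessing.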
@@ -5,4 +5,4 @@ services:
 volumes:
 # Uncomment the following lines to enable
 # # Mount TRTLLM data volume:
-# - /home/scratch.trt_llm_data/:/home/scratch.trt_llm_data/:ro
+# - /home/scratch.trt_llm_data_ci/:/home/scratch.trt_llm_data_ci/:ro
31 .github/CODEOWNERS vendored
@@ -1,5 +1,18 @@
 # This file defines code ownership rules for the repository.
 
+## TensorRT-LLM QA
+### Integration Tests
+/tests/integration/test_lists/qa @NVIDIA/trt-llm-qa
+/tests/integration/defs/examples/test_ray.py @NVIDIA/trt-llm-qa-function
+/tests/integration/defs/examples/test_redrafter.py @NVIDIA/trt-llm-qa-function
+/tests/integration/defs/accuracy @NVIDIA/trt-llm-qa-function
+/tests/integration/defs/stress_test @NVIDIA/trt-llm-qa-function
+/tests/integration/defs/triton_server @NVIDIA/trt-llm-qa-function
+/tests/integration/defs/test_e2e.py @NVIDIA/trt-llm-qa-function
+/tests/integration/defs/disaggregated @NVIDIA/trt-llm-qa-serving
+/tests/integration/defs/sysinfo @NVIDIA/trt-llm-qa-perf
+/tests/integration/defs/perf @NVIDIA/trt-llm-qa-perf
+/tests/integration/defs/perf/disagg @NVIDIA/trt-llm-qa-serving
 
 ## TensorRT-LLM Infra
 ### CI
@@ -13,6 +26,13 @@
 
 ## TensorRT-LLM - Docs
 /docs @NVIDIA/trt-llm-doc-owners
+/CODING_GUIDELINES.md @NVIDIA/trt-llm-doc-owners
+/CODE_OF_CONDUCT.md @NVIDIA/trt-llm-doc-owners
+/CONTAINER_SOURCE.md @NVIDIA/trt-llm-doc-owners
+/CONTRIBUTING.md @NVIDIA/trt-llm-doc-owners
+/README.md @NVIDIA/trt-llm-doc-owners
+/CLAUDE.md @NVIDIA/trt-llm-doc-owners
+/AGENTS.md @NVIDIA/trt-llm-doc-owners
 
 ## Examples
 /examples @NVIDIA/trt-llm-doc-owners
@@ -136,6 +156,8 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
 ## TensorRT-LLM LLM API
 /tensorrt_llm/llmapi @NVIDIA/trt-llm-llmapi-devs
 /tensorrt_llm/executor @NVIDIA/trt-llm-llmapi-devs
+/tensorrt_llm/serve @NVIDIA/trt-llm-llmapi-devs
+/tensorrt_llm/commands @NVIDIA/trt-llm-llmapi-devs
 
 ## TensorRT-LLM LLM Disaggregated
 /examples/disaggregated @NVIDIA/trt-llm-disagg-devs @NVIDIA/trt-llm-doc-owners
@@ -167,8 +189,6 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
 /tensorrt_llm/_torch/pyexecutor/resource_manager.py @NVIDIA/trt-llm-kv-cache-manager-devs
 /cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.h @NVIDIA/trt-llm-kv-cache-manager-devs
 /cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @NVIDIA/trt-llm-kv-cache-manager-devs
-/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.h @NVIDIA/trt-llm-kv-cache-manager-devs
-/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @NVIDIA/trt-llm-kv-cache-manager-devs
 
 # The rule below requires that any PR modifying public APIs must be approved by at least one member
 # of the NVIDIA/trt-llm-committed-api-review-committee or NVIDIA/trt-llm-noncommitted-api-review-committee team.
@@ -183,6 +203,7 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
 ## and license compliance when adding, removing, or changing versions of dependencies.
 ### License Files
 /LICENSE @NVIDIA/trt-llm-oss-compliance
+/ATTRIBUTIONS-*.md @NVIDIA/trt-llm-oss-compliance
 /jenkins/license_cpp.json @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs @NVIDIA/trt-llm-oss-compliance
 
 ### Python Dependency Management
@@ -200,6 +221,12 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
 ## Any changes to versions, additions, or removals of third-party libraries
 /3rdparty/** @NVIDIA/trt-llm-oss-compliance
 
+### Vendored Third-Party Code (triton-kernels)
+## This is a temporary vendored copy of triton-kernels from the Triton project (MIT License).
+## Do not accept contributions to this directory - it should only be updated via scripts/vendor_triton_kernels.py
+## This can be removed if and when triton-kernels is published as a separate wheel.
+/triton_kernels/** @NVIDIA/trt-llm-oss-compliance
+
 ### Docker & Installation Scripts
 ## These scripts install and pin dependency versions
 /docker/common/** @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs @NVIDIA/trt-llm-oss-compliance
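One property of CODEOWNERS worth keeping in mind when reading the hunks above (general GitHub behavior, not something this diff states): when several patterns match a path, the last matching rule in the file wins, so the more specific /tests/integration/defs/perf/disagg entry overrides the broader /tests/integration/defs/perf entry for files under that subdirectory.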
4 .github/tava_architecture_diagram.md vendored
@@ -55,8 +55,8 @@ graph TB
 Sampling[Sampling]
 BatchManager[Batch Manager]
 KVCache[KV Cache Manager]
-PyScheduler --> |Pybind|Shared_Scheduler
-PyDecoder --> |Pybind|Shared_Decoder
+PyScheduler --> |Nanobind|Shared_Scheduler
+PyDecoder --> |Nanobind|Shared_Decoder
 Executor --> Shared_Decoder
 Shared_Decoder --> Sampling
 Executor --> Shared_Scheduler[Scheduler]
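The relabeled Mermaid edges track the same pybind11-to-nanobind migration visible in the CODEOWNERS hunks above, where the cpp/tensorrt_llm/pybind kvCacheManager entries are dropped while their cpp/tensorrt_llm/nanobind counterparts are kept.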
4 .github/workflows/auto-assign.yml vendored
@@ -11,10 +11,10 @@ jobs:
 runs-on: ubuntu-latest
 steps:
 - name: Checkout repository
-uses: actions/checkout@v2
+uses: actions/checkout@v6
 
 - name: Get assignee
-uses: actions/github-script@v6
+uses: actions/github-script@v8
 id: get-assignee
 with:
 github-token: ${{secrets.GITHUB_TOKEN}}

@@ -14,7 +14,7 @@ jobs:
 pull-requests: write
 
 steps:
-- uses: actions/stale@v9
+- uses: actions/stale@v10
 with:
 repo-token: ${{ secrets.GITHUB_TOKEN }}
 stale-issue-message: 'Issue has not received an update in over 14 days. Adding stale label.'
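These action-version bumps, like the matching ones in the workflows below, move the jobs onto currently supported action releases; older tags such as actions/checkout@v2 still target Node.js runtimes that GitHub has since deprecated.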
27 .github/workflows/blossom-ci.yml vendored
@@ -53,6 +53,7 @@ jobs:
 "amukkara",
 "anish-shanbhag",
 "arekay",
+"arysef",
 "atrifex",
 "Autumn1998",
 "baize97",
@@ -62,19 +63,22 @@ jobs:
 "BatshevaBlack",
 "benzh-2025",
 "BestJuly",
+"binghanc",
+"bmarimuthu-nv",
 "bo-nv",
 "bobboli",
 "Boreas618",
 "brb-nv",
 "byshiue",
 "CarstyYou",
+"cascade812",
 "chang-l",
 "chenfeiz0326",
 "cherichy",
 "cheshirekow",
+"chienchunhung",
 "ChristinaZ",
 "chuangz0",
-"ChunhuanLin",
 "chzblych",
 "cjluo-nv",
 "crazydemo",
@@ -94,19 +98,21 @@ jobs:
 "dongfengy",
 "dongjiyingdjy",
 "dongxuy04",
+"dpitman-nvda",
 "DylanChen-NV",
 "ebarilanM",
 "elvischenv",
 "EmmaQiaoCh",
 "eopXD",
+"esha-nvidia",
 "evezhier",
 "faradawn",
 "farazkh80",
-"FelixXidddd",
 "flin3500",
 "FrankD412",
 "fredricz-20070104",
 "Fridah-nv",
+"fsaady",
 "funatiq",
 "fzmu727",
 "galagam",
@@ -121,6 +127,7 @@ jobs:
 "heyuhhh",
 "hijkzzz",
 "hlu1",
+"hnover-nv",
 "HuiGao-NV",
 "hvagadia",
 "hypdeb",
@@ -142,6 +149,7 @@ jobs:
 "Jie-Fang",
 "jiefangz-nv",
 "jieli-matrix",
+"JintaoPengCS",
 "jinyangyuan-nvidia",
 "jinzh-nvidia",
 "jmydurant",
@@ -154,6 +162,7 @@ jobs:
 "kaiyux",
 "kanghui0204",
 "karljang",
+"karthikvetrivel",
 "katec846",
 "Kefeng-Duan",
 "KingsleyLiu-NV",
@@ -172,6 +181,7 @@ jobs:
 "linda-stadter",
 "lingjiew",
 "LinPoly",
+"lirundong",
 "litaotju",
 "liyuhannnnn",
 "lkomali",
@@ -179,6 +189,8 @@ jobs:
 "lowsfer",
 "lucaslie",
 "lucifer1004",
+"luyiyun1021",
+"marinayanov",
 "MartinMarciniszyn",
 "MatthiasKohl",
 "mayani-nv",
@@ -191,14 +203,17 @@ jobs:
 "mlefeb01",
 "moraxu",
 "MrGeva",
+"mzweilz",
 "Naveassaf",
 "nekorobov",
 "netanel-haber",
 "niukuo",
 "Njuapp",
+"nv-ananjappa",
 "nv-guomingz",
 "nv-lschneider",
 "nv-yilinf",
+"nv-yna",
 "nvamyt",
 "nvbrantz",
 "nvchenghaoz",
@@ -215,8 +230,10 @@ jobs:
 "omera-nv",
 "pamelap-nvidia",
 "pcastonguay",
+"pcicotti",
 "pdrake-nv",
 "peaceh-nv",
+"peihu-nv",
 "pengbowang-nv",
 "PerkzZheng",
 "poweiw",
@@ -243,6 +260,7 @@ jobs:
 "schetlur-nv",
 "shaharmor98",
 "shangz-ai",
+"sherry-1001",
 "shifangx",
 "Shixiaowei02",
 "Shunkangz",
@@ -262,6 +280,7 @@ jobs:
 "syuoni",
 "Tabrizian",
 "talorabr",
+"taylor-yb-lee",
 "tburt-nv",
 "tcherckez-nvidia",
 "thorjohnsen",
@@ -283,7 +302,6 @@ jobs:
 "vegaluisjose",
 "venkywonka",
 "viraatc",
-"wangsiping1997",
 "Wanli-Jiang",
 "WeiHaocheng",
 "weireweire",
@@ -294,9 +312,11 @@ jobs:
 "wu6u3tw",
 "wyw1267",
 "xavier-nvidia",
+"xd-nv",
 "xiaoweiw-nv",
 "xinhe-nv",
 "xmchen1987",
+"xrq-phys",
 "xuanzic",
 "xueweilnvidia",
 "xupinjie",
@@ -307,6 +327,7 @@ jobs:
 "yibinl-nvidia",
 "yifeizhang-c",
 "yihwang-nv",
+"yijingl-nvidia",
 "yilin-void",
 "yingcanw",
 "yingguo-trt",
5 .github/workflows/bot-command.yml vendored
@@ -36,7 +36,7 @@ jobs:
 runs-on: ubuntu-latest
 steps:
 - name: Add bot help comment
-uses: actions/github-script@v6
+uses: actions/github-script@v8
 with:
 script: |
 const helpMessage = "" +
@@ -46,7 +46,7 @@ jobs:
 "Run `/bot [-h|--help]` to print this help message.\n\n" +
 "See details below for each supported subcommand.\n\n" +
 "<details>\n\n" +
-"`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list \"A10-PyTorch-1, xxx\" --gpu-type \"A30, H100_PCIe\" --test-backend \"pytorch, cpp\" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage \"H100_PCIe-TensorRT-Post-Merge-1, xxx\" --detailed-log --debug(experimental)]`\n\n" +
+"`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list \"A10-PyTorch-1, xxx\" --gpu-type \"A30, H100_PCIe\" --test-backend \"pytorch, cpp\" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage \"H100_PCIe-TensorRT-Post-Merge-1, xxx\" --detailed-log --debug(experimental) --high-priority]`\n\n" +
 "Launch build/test pipelines. All previously running jobs will be killed.\n\n" +
 "`--reuse-test (optional)pipeline-id ` *(OPTIONAL)* : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.\n\n" +
 "`--disable-reuse-test ` *(OPTIONAL)* : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.\n\n" +
@@ -62,6 +62,7 @@ jobs:
 "`--extra-stage \"H100_PCIe-TensorRT-Post-Merge-1, xxx\"` *(OPTIONAL)* : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage \"H100_PCIe-TensorRT-Post-Merge-1, xxx\".\n\n" +
 "`--detailed-log ` *(OPTIONAL)* : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.\n\n" +
 "`--debug ` *(OPTIONAL)* : **Experimental feature**. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does **NOT** update GitHub check status.\n\n" +
+"`--high-priority ` *(OPTIONAL)* : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.\n\n" +
 "### kill\n\n" +
 "`kill `\n\n" +
 "Kill all running builds associated with pull request.\n\n" +
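Reading the help text as a whole, the documented flags compose into a single PR comment. For instance (a hypothetical combination, with a placeholder stage name), a comment of the form /bot run --disable-fail-fast --stage-list "A10-PyTorch-1" --high-priority would launch a pipeline restricted to that stage, without fail-fast, on the new high-priority queue, assuming the comment author is on the authorized-user list.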
4 .github/workflows/l0-test.yml vendored
@@ -34,7 +34,7 @@ jobs:
 if: github.event_name == 'workflow_dispatch'
 steps:
 - name: Update commit status
-uses: actions/github-script@v6
+uses: actions/github-script@v8
 with:
 script: |
 state = 'pending'
@@ -60,7 +60,7 @@ jobs:
 with:
 paths: results/**/results*.xml
 - name: Update commit status
-uses: actions/github-script@v6
+uses: actions/github-script@v8
 with:
 script: |
 github.rest.repos.createCommitStatus({
4 .github/workflows/label_community_pr.yml vendored
@@ -17,10 +17,10 @@ jobs:
 if: github.repository == 'NVIDIA/TensorRT-LLM'
 steps:
 - name: Checkout repository
-uses: actions/checkout@v3
+uses: actions/checkout@v6
 
 - name: Set up Python
-uses: actions/setup-python@v3
+uses: actions/setup-python@v6
 with:
 python-version: '3.x'
 
2 .github/workflows/label_issue.yml vendored
@@ -13,7 +13,7 @@ jobs:
 runs-on: ubuntu-latest
 steps:
 - name: Checkout private action repository
-uses: actions/checkout@v4
+uses: actions/checkout@v6
 with:
 repository: NVIDIA/goggles_action
 path: ./.github/actions/goggles_action # local path to store the action
4 .github/workflows/pr-check.yml vendored
@@ -59,10 +59,10 @@ jobs:
 runs-on: ubuntu-latest
 steps:
 - name: Checkout repository
-uses: actions/checkout@v4
+uses: actions/checkout@v6
 
 - name: Set up Python
-uses: actions/setup-python@v5
+uses: actions/setup-python@v6
 with:
 python-version: '3.10'
 
4 .github/workflows/precommit-check.yml vendored
@@ -29,11 +29,11 @@ jobs:
 name: Pre-commit Check
 runs-on: ubuntu-latest
 steps:
-- uses: actions/checkout@v4
+- uses: actions/checkout@v6
 with:
 ref: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.ref || github.ref }}
 
-- uses: actions/setup-python@v5
+- uses: actions/setup-python@v6
 with:
 python-version: '3.12'
 cache: 'pip'
13 .gitignore vendored
@@ -1,10 +1,13 @@
 __pycache__/
 .vscode
+.cursor
 *.engine
 *.engine.config
 *.cache
 *.nsys-rep
 *.npy
+*.so
+*.whl
 .VSCodeCounter
 cpp/build*
 cpp/Release
@@ -40,6 +43,8 @@ tensorrt_llm/libs
 tensorrt_llm/bindings.*.so
 tensorrt_llm/bindings.pyi
 tensorrt_llm/bindings/**/*.pyi
+tensorrt_llm/tensorrt_llm_transfer_agent_binding.*.so
+tensorrt_llm/tensorrt_llm_transfer_agent_binding.pyi
 tensorrt_llm/deep_ep/
 tensorrt_llm/deep_ep_cpp_tllm.*.so
 tensorrt_llm/deep_ep_cpp_tllm.pyi
@@ -50,25 +55,29 @@ tensorrt_llm/pg_utils_bindings.*.so
 tensorrt_llm/flash_mla/
 tensorrt_llm/flash_mla_cpp_tllm.*.so
 tensorrt_llm/flash_mla_cpp_tllm.pyi
+tensorrt_llm/runtime/kv_cache_manager_v2/**/*.so
+**/*__mypyc*.so
 tensorrt_llm/scripts
 *docs/cpp_docs*
 *docs/source/_cpp_gen*
 docs/source/**/*.rst
 !docs/source/examples/index.rst
 !docs/source/deployment-guide/config_table.rst
-!docs/source/deployment-guide/note_sections.rst
+!docs/source/_includes/note_sections.rst
 *.swp
 
 # Testing
 .coverage.*
 results_trt/
 llm-test-workspace/
+ad-test-workspace/
 
 # build/debug
 *.safetensors
 */tllm_debug/**
 *.patch
 !cpp/tensorrt_llm/deep_ep/*.patch
+examples/disaggregated/slurm/benchmark/logs/
 
 # Generated files
 cpp/include/tensorrt_llm/executor/version.h
@@ -76,9 +85,11 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
 cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
 cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
 .devcontainer/.env
+/examples/layer_wise_benchmarks/autotuner_cache/
 /examples/layer_wise_benchmarks/profiles/
 
 # User config files
+CLAUDE.local.md
 CMakeUserPresets.json
 compile_commands.json
 *.bin
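Standard .gitignore semantics explain the pattern mix in the third hunk: a leading "!" re-includes a path that an earlier pattern excluded, so docs/source/**/*.rst ignores generated .rst files while the negated entries keep a few hand-maintained ones tracked. The change here simply re-points the negation from docs/source/deployment-guide/note_sections.rst to the file's new home under docs/source/_includes/.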
@ -951,7 +951,6 @@ common-files: &common_files |
|
|||||||
tests/unittest/_torch/attention/test_attention_no_cache.py |
|
tests/unittest/_torch/attention/test_attention_no_cache.py |
|
||||||
tests/unittest/_torch/attention/test_attention.py |
|
tests/unittest/_torch/attention/test_attention.py |
|
||||||
tests/unittest/_torch/attention/test_flashinfer_attention.py |
|
tests/unittest/_torch/attention/test_flashinfer_attention.py |
|
||||||
tests/unittest/_torch/attention/test_flashinfer_star_attn.py |
|
|
||||||
tests/unittest/_torch/attention/test_vanilla_attention.py |
|
tests/unittest/_torch/attention/test_vanilla_attention.py |
|
||||||
tests/unittest/_torch/compilation/test_add_norm.py |
|
tests/unittest/_torch/compilation/test_add_norm.py |
|
||||||
tests/unittest/_torch/debugger/test_debugger_addon.py |
|
tests/unittest/_torch/debugger/test_debugger_addon.py |
|
||||||
@ -1004,7 +1003,6 @@ common-files: &common_files |
|
|||||||
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py |
|
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py |
|
||||||
tests/unittest/_torch/multi_gpu/test_mnnvl_memory.py |
|
tests/unittest/_torch/multi_gpu/test_mnnvl_memory.py |
|
||||||
tests/unittest/_torch/multi_gpu/test_moe_a2a.py |
|
tests/unittest/_torch/multi_gpu/test_moe_a2a.py |
|
||||||
tests/unittest/_torch/multi_gpu/test_star_attention.py |
|
|
||||||
tests/unittest/_torch/multi_gpu/test_user_buffers.py |
|
tests/unittest/_torch/multi_gpu/test_user_buffers.py |
|
||||||
tests/unittest/_torch/multimodal/test_external_embedding.py |
|
tests/unittest/_torch/multimodal/test_external_embedding.py |
|
||||||
tests/unittest/_torch/multimodal/test_find_num_image_tokens.py |
|
tests/unittest/_torch/multimodal/test_find_num_image_tokens.py |
|
||||||
@ -1022,7 +1020,6 @@ common-files: &common_files |
|
|||||||
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
|
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
|
||||||
tests/unittest/_torch/sampler/test_beam_search.py |
|
tests/unittest/_torch/sampler/test_beam_search.py |
|
||||||
tests/unittest/_torch/sampler/test_best_of_n.py |
|
tests/unittest/_torch/sampler/test_best_of_n.py |
|
||||||
tests/unittest/_torch/sampler/test_return_logits.py |
|
|
||||||
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
|
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
|
||||||
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
|
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
|
||||||
tests/unittest/_torch/speculative/test_draft_target.py |
|
tests/unittest/_torch/speculative/test_draft_target.py |
|
||||||
@ -1369,6 +1366,9 @@ common-files: &common_files |
|
|||||||
triton_backend/tools/whisper/client.py |
|
triton_backend/tools/whisper/client.py |
|
||||||
)$
|
)$
|
||||||
|
|
||||||
|
# Global exclude pattern for vendored third-party code
|
||||||
|
exclude: '^triton_kernels/'
|
||||||
|
|
||||||
default_install_hook_types: [pre-commit, commit-msg]
|
default_install_hook_types: [pre-commit, commit-msg]
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
@ -1398,7 +1398,7 @@ repos:
|
|||||||
exclude: |
|
exclude: |
|
||||||
(?x)^(.*cubin.cpp | .*cubin.h)$
|
(?x)^(.*cubin.cpp | .*cubin.h)$
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
args: [--allow-multiple-documents]
|
args: [--allow-multiple-documents, --unsafe]
|
||||||
exclude: ".*/gitlab/.*.yml"
|
exclude: ".*/gitlab/.*.yml"
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
exclude: '\.(patch|md)$'
|
exclude: '\.(patch|md)$'
|
||||||
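
One detail worth calling out in the hooks above: the `(?x)` prefix in the cubin `exclude` pattern switches Python's regex engine into verbose mode, so the spaces around the `|` alternation are ignored rather than matched literally. A minimal sketch (the file names are illustrative):

```python
import re

# (?x) = verbose mode: unescaped whitespace in the pattern is ignored,
# so ".*cubin.cpp | .*cubin.h" behaves like ".*cubin.cpp|.*cubin.h".
pattern = re.compile(r"(?x)^(.*cubin.cpp | .*cubin.h)$")

assert pattern.match("kernels/fmha_cubin.cpp")    # excluded by the hook
assert pattern.match("kernels/fmha_cubin.h")      # excluded by the hook
assert pattern.match("kernels/other.cu") is None  # still checked
```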

3rdparty/CMakeLists.txt (4 changes, vendored)
@@ -38,8 +38,8 @@ FetchContent_Declare(

 FetchContent_Declare(
   deepgemm
-  GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM
-  GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch
+  GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
+  GIT_TAG 4ff3f54d9b7ed3129e4f36f9871232ea7ecab86b # nv_dev branch
   GIT_SUBMODULES_RECURSE
   ON
   SOURCE_SUBDIR

AGENTS.md (new file, 148 lines)

# AGENTS.md

TensorRT-LLM: open-source library for optimized LLM inference on NVIDIA GPUs.
Python and C++ codebase supporting TensorRT engine-based and PyTorch-based execution paths.

> If a `CLAUDE.local.md` file exists alongside this file, read and respect it — it contains developer-specific overrides that supplement this shared guidance.

## Rules (Read First)

**CRITICAL (YOU MUST):**
- Read and follow `CODING_GUIDELINES.md` for ALL code changes (C++ and Python)
- NVIDIA copyright header on ALL new files (update the year on modified files)
- `git commit -s` (DCO sign-off required). Never attribute AI tools in the sign-off line
- `pre-commit` hooks run on commit — if files are modified by hooks, re-stage and commit again
- PR title format: `[JIRA/NVBUG/None][type] description` (e.g., `[TRTLLM-5516][perf] optimize cuda graph padding`)
- Python imports: `from package.subpackage import module`, never `from module import Class` (see the sketch after this list)
- Set the `LLM_MODELS_ROOT` env var when running tests that need model weights
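
To make the import rule concrete, a short sketch; `llm_utils` is a real module in the repo, while the rejected import is an illustrative counterexample:

```python
# Preferred: import the module via its package path, then qualify names at
# the call site so the origin stays visible.
from tensorrt_llm.llmapi import llm_utils

# Avoid: binding names directly out of a module, e.g.
#   from tensorrt_llm.llmapi.llm_utils import SomeClass  # (illustrative name)
```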

## Common Commands

| Task | Command |
|------|---------|
| Unit tests | `pytest tests/unittest/` |
| Specific test | `pytest tests/unittest/llmapi/test_llm_args.py` |
| Pattern match | `pytest tests/unittest -k "test_llm_args"` |
| Integration tests | `LLM_MODELS_ROOT=/path/to/models pytest tests/integration/defs/...` |
| Serve model | `trtllm-serve --model <hf_model> --port 8000` |
| Serve with config | `trtllm-serve --model <hf_model> --config config.yaml` |
| Benchmark | `trtllm-bench --model <hf_model> --dataset_path <path>` |
| Find CI stage for test | `python scripts/test_to_stage_mapping.py --tests "test_name"` |

### Installation & Build

Building TensorRT-LLM requires Docker and may involve compiling C++ components.
See [build from source](docs/source/installation/build-from-source-linux.md) for full instructions,
or [pip install](docs/source/installation/linux.md) for pre-built wheels.
For container images, see [NGC containers](docs/source/installation/containers.md).

### Reference Configs

`examples/configs/database/` contains 170+ Pareto-optimized serving configurations
across multiple models, GPUs, ISL/OSL combinations, and concurrency levels.
Use these as starting points for deployment and benchmarking rather than hand-tuning parameters.
See [deployment guides](docs/source/deployment-guide/) for model-specific walkthroughs.

## Architecture

See [architecture diagram](.github/tava_architecture_diagram.md) for the full Mermaid diagram.

### Backends

The entry points in this table are illustrated in a short sketch right after it.

| Backend | Status | Entry Point | Key Path |
|---------|--------|-------------|----------|
| **PyTorch** | Default | `LLM(backend="pytorch")` | `_torch/pyexecutor/` → `PyExecutor` → PyTorch Engine |
| **AutoDeploy** | Beta | `LLM(backend="_autodeploy")` | `_torch/auto_deploy/` → `ADExecutor` → graph transforms + torch.export |
| **TensorRT** | Legacy | `LLM(backend="tensorrt")` | `builder.py` → `trtllm.Executor` → TensorRT Engine |
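
A minimal sketch of the entry point shown above, following the public quick-start API; the model path and prompt are placeholders, and the exact `SamplingParams` kwargs should be treated as an assumption:

```python
from tensorrt_llm import LLM, SamplingParams

# backend defaults to the PyTorch path; "_autodeploy" or "tensorrt" select the others.
llm = LLM(model="<hf_model_or_local_path>")

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
for output in outputs:
    print(output.outputs[0].text)
```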

### Shared C++ Core (via Nanobind)

Both PyTorch and TensorRT backends share these C++ components:
- **Scheduling pipeline**: Scheduler → BatchManager (in-flight batching) → KV Cache Manager
- **Decoding pipeline**: Decoder (token generation orchestration) → Sampling

### Request Flow

```text
HuggingFace Model → LLM API → Executor (PyTorch/AutoDeploy/TensorRT)
  → Scheduler → Model Forward → Decoder → Sampling → Generated Tokens
```

### Serving

- `trtllm-serve`: OpenAI-compatible REST + gRPC server, supports all backends
- **Disaggregated serving**: separates prefill (context) and decode (generation) across GPUs
  - KV cache exchange via NIXL (default), UCX, or MPI

## Key Files

| File | Role |
|------|------|
| `tensorrt_llm/llmapi/llm.py` | Main API entry point |
| `tensorrt_llm/llmapi/llm_args.py` | Complete configuration schema (Pydantic) |
| `tensorrt_llm/llmapi/llm_utils.py` | Model loading, model-specific default overrides |
| `tensorrt_llm/models/modeling_utils.py` | Base classes for all models (`PretrainedConfig`, `PretrainedModel`) |
| `tensorrt_llm/executor/executor.py` | Execution abstraction (`GenerationExecutor`) |
| `tensorrt_llm/models/automodel.py` | Auto-discovery and model registry |

## Design Patterns

| Pattern | Key Points |
|---------|------------|
| **Config hierarchy** | `LlmArgs` → `TrtLlmArgs` / `TorchLlmArgs`, model-specific defaults override generics, Pydantic validation |
| **Model architecture** | Each model: `Config` (inherits `PretrainedConfig`) + `ForCausalLM` (inherits `PretrainedModel`) |
| **Model defaults** | Architecture-specific overrides in `llm_utils.py` (attention kernels, quant, spec decoding, cache) |
| **Distributed execution** | Tensor/pipeline parallelism via `Mapping` class, multiple backends (MPI, Ray, RPC) |
| **Auto-discovery** | Models self-register via `automodel.py`, resolved by HF config `architectures` field |

## Anti-Patterns / Gotchas

- **Pre-commit modifies files in-place** — if hooks fail, files are already modified. Re-stage (`git add`) and commit again.
- **Protected APIs exist** — changes to LLM API signatures will fail `tests/api_stability` tests. Get code owner review.
- **Integration tests need GPUs + models** — always set `LLM_MODELS_ROOT` and ensure GPU access. Unit tests don't.
- **Copyright year** — update to the current year when modifying existing files; add the full header to new files.
- **Avoid broad exception handling** — catch specific exceptions, not bare `except:` (see `CODING_GUIDELINES.md`); a sketch follows this list.
- **Python import style is enforced** — `from package.subpackage import module`, never `from module import Class`. Pre-commit will not catch this.
- **One concern per PR** — avoid scope creep. If a PR touches unrelated areas, split it.
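
As a concrete sketch of the exception-handling gotcha (the file name and logger are illustrative):

```python
import logging

logger = logging.getLogger(__name__)

# Avoid: a bare `except:` also swallows KeyboardInterrupt and SystemExit.
# try:
#     config = open("config.yaml").read()
# except:
#     config = None

# Prefer: catch the specific failure you expect and keep it observable.
try:
    with open("config.yaml") as f:
        config = f.read()
except FileNotFoundError as exc:
    logger.warning("config.yaml not found, falling back to defaults: %s", exc)
    config = None
```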

## Development Workflow

1. Set up build environment (see [installation docs](docs/source/installation/))
2. Make changes following `CODING_GUIDELINES.md`
3. Test locally with `pytest`
4. Submit PR:
   - PR title format: `[JIRA/NVBUG/None][type] description` (e.g., `[TRTLLM-5516][perf] optimize cuda graph padding`)
   - Sign commits with DCO (`git commit -s`)
   - Target `main` unless fixing a release branch bug
   - See `CONTRIBUTING.md` for full PR policies

## CI / Testing

See [CI overview](docs/source/developer-guide/ci-overview.md) for full details.

| Layer | Location | Notes |
|-------|----------|-------|
| Unit tests | `tests/unittest/` | Run in pre-merge CI; some tests require GPU |
| API stability | `tests/api_stability/` | Protects committed API signatures |
| Integration tests | `tests/integration/defs/` | Requires GPU + `LLM_MODELS_ROOT` |
| Test lists | `tests/integration/test_lists/test-db/` | Per-GPU YAML files (`l0_a10.yml`, `l0_h100.yml`, etc.) |
| Test waives | `tests/integration/test_lists/waives.txt` | Skip known-failing tests with NVBug links |
| Performance | See [benchmarking guide](docs/source/developer-guide/perf-benchmarking.md) | `trtllm-bench` and `trtllm-serve` benchmarks |

## Key Documentation

| Topic | Path |
|-------|------|
| Architecture overview | `docs/source/developer-guide/overview.md` |
| PyTorch backend | `docs/source/torch/arch_overview.md` |
| Adding a new model | `docs/source/torch/adding_new_model.md` |
| AutoDeploy | `docs/source/features/auto_deploy/auto-deploy.md` |
| Disaggregated serving | `docs/source/features/disagg-serving.md` |
| Speculative decoding | `docs/source/features/speculative-decoding.md` |
| Quantization | `docs/source/features/quantization.md` |
| Parallelism strategies | `docs/source/features/parallel-strategy.md` |
| KV cache | `docs/source/features/kvcache.md` |
| API change guidelines | `docs/source/developer-guide/api-change.md` |
| Feature compatibility matrix | `docs/source/features/feature-combination-matrix.md` |
| Supported models | `docs/source/models/supported-models.md` |
| Deployment guides | `docs/source/deployment-guide/` |
| Examples & customization | `docs/source/examples/` |
| Performance analysis | `docs/source/developer-guide/perf-analysis.md` |

(File diff suppressed because it is too large.)

CODING_GUIDELINES.md
@@ -487,9 +487,27 @@ else:
         f.read()
 ```

+## Documentation Guidelines
+
+#### CLI Options in Documentation
+1. When documenting CLI commands for `trtllm-serve`, `trtllm-bench`, `trtllm-eval`, or similar tools, prefer using `--config` over `--extra_llm_api_options` for specifying configuration files.
+   - `--config` is the preferred, shorter alias for configuration file options.
+   - Example: `trtllm-serve --model <model_path> --config config.yaml` (preferred)
+   - Avoid: `trtllm-serve --model <model_path> --extra_llm_api_options config.yaml`
+
+## AI Coding Agent Guidance
+
+This repository includes configuration files for AI coding agents (Claude Code, Cursor, Codex, Copilot, etc.):
+
+- **`AGENTS.md`** — Shared project context, rules, architecture pointers, and commands. Checked into git.
+- **`CLAUDE.md`** — A simple `@AGENTS.md` import indirection for Claude Code.
+- **`CLAUDE.local.md`** — Personal developer overrides (gitignored). Create this file for your own preferences, local paths, or domain-specific context without affecting the shared config.
+
+**Keeping `AGENTS.md` up to date**: If you change workflows, commands, architecture, or conventions that would benefit all developers and AI agents, update `AGENTS.md` in the same PR. It should evolve at the pace of the code.
+
 ## NVIDIA Copyright

-1. All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. The following block of text should be prepended to the top of all files. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
+1. All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification. The following block of text should be prepended to the top of all files. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
 ```cpp
 /*
  * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

LICENSE (8 changes)
@@ -41,6 +41,14 @@ Original Source: https://github.com/state-spaces/mamba
 Copyright 2023 Tri Dao, Albert Gu
 Licensed under the Apache License 2.0

+--------------------------------------------------------------------------------
+Quack
+--------------------------------------------------------------------------------
+Original Source: https://github.com/Dao-AILab/quack
+Copyright (c) 2025, Tri Dao.
+Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+Licensed under the Apache License 2.0
+
 --------------------------------------------------------------------------------
 SGLang
 --------------------------------------------------------------------------------

README.md (15 changes)
@@ -6,11 +6,12 @@ TensorRT LLM
 state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.</h4>

 [](https://nvidia.github.io/TensorRT-LLM/)
+[](https://deepwiki.com/NVIDIA/TensorRT-LLM)
 [](https://www.python.org/downloads/release/python-3123/)
 [](https://www.python.org/downloads/release/python-31012/)
 [](https://developer.nvidia.com/cuda-downloads)
 [](https://pytorch.org)
 [](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
 [](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)

 [Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html) | [Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html) | [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) | [Documentation](https://nvidia.github.io/TensorRT-LLM/) | [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
@@ -20,6 +21,12 @@ state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.<

 ## Tech Blogs

+* [02/06] Accelerating Long-Context Inference with Skip Softmax Attention
+✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog16_Accelerating_Long_Context_Inference_with_Skip_Softmax_Attention.html)
+
+* [01/09] Optimizing DeepSeek-V3.2 on NVIDIA Blackwell GPUs
+✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog15_Optimizing_DeepSeek_V32_on_NVIDIA_Blackwell_GPUs)
+
 * [10/13] Scaling Expert Parallelism in TensorRT LLM (Part 3: Pushing the Performance Boundary)
 ✨ [➡️ link](https://nvidia.github.io/TensorRT-LLM/blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.html)

@@ -267,5 +274,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer
 ## Useful Links
 - [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT LLM.
 - [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT LLM.
-- [AutoDeploy](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html): A prototype backend for TensorRT LLM to simplify and accelerate the deployment of PyTorch models.
+- [AutoDeploy](https://nvidia.github.io/TensorRT-LLM/features/auto_deploy/auto-deploy.html): A beta backend for TensorRT LLM to simplify and accelerate the deployment of PyTorch models.
 - [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT LLM Q&A and news.

SECURITY.md (new file, 31 lines)

<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Report a Security Vulnerability

To report a potential security vulnerability in any NVIDIA product, please use either:
* This web form: [Security Vulnerability Submission Form](https://www.nvidia.com/en-us/support/submit-security-vulnerability/), or
* Send email to: [NVIDIA PSIRT](mailto:psirt@nvidia.com)

If reporting a potential vulnerability via email, please encrypt it using NVIDIA’s public PGP key ([see PGP Key page](https://www.nvidia.com/en-us/security/pgp-key/)) and include the following information:
1. Product/Driver name and version/branch that contains the vulnerability
2. Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
3. Instructions to reproduce the vulnerability
4. Proof-of-concept or exploit code
5. Potential impact of the vulnerability, including how an attacker could exploit the vulnerability

See https://www.nvidia.com/en-us/security/ for past NVIDIA Security Bulletins and Notices.
@@ -2,10 +2,11 @@ import logging
 import random
 import re
 import tempfile
+from pathlib import Path
 from typing import Optional

 import click
-from datasets import load_dataset
+import datasets
 from PIL import Image
 from pydantic import BaseModel, model_validator
 from utils.utils import (get_norm_dist_lengths, multimodal_dataset_dump,
@@ -29,7 +30,7 @@ def validate_output_len_dist(ctx, param, value):
 class DatasetConfig(BaseModel):
     """Dataset configurations."""
     """Name of the dataset on HuggingFace."""
-    name: str
+    name: Optional[str] = None
     """Config name of the dataset if existing."""
     config_name: Optional[str] = None
     """Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits."""
@@ -44,6 +45,8 @@ class DatasetConfig(BaseModel):
     prompt: Optional[str] = None
     """The dataset dictionary key used to derive the output sequence length. Set to None if the dataset does not have a key for output."""
     output_key: Optional[str]
+    """The local path to the dataset to be loaded when using a local cache."""
+    local_path: Optional[str] = None

     @model_validator(mode='after')
     def check_prompt(self) -> 'DatasetConfig':
@@ -54,19 +57,40 @@ class DatasetConfig(BaseModel):
             raise AssertionError("Either --prompt-key or --prompt must be set.")
         return self

+    @model_validator(mode='after')
+    def check_name_and_local_path(self) -> 'DatasetConfig':
+        if self.name and self.local_path:
+            raise AssertionError(
+                "--dataset-name and --dataset-local-path cannot be set at the same time."
+            )
+        if (not self.name) and (not self.local_path):
+            raise AssertionError(
+                "Either --dataset-name or --dataset-local-path must be set.")
+        return self
+
     @property
     def query(self):
         """Generate the query for HuggingFace `datasets.load_dataset()`"""
+        first_arg = self.local_path if self.local_path else self.name
+
         if self.config_name:
-            return [self.name, self.config_name]
+            return [first_arg, self.config_name]
         else:
-            return [self.name]
+            return [first_arg]
+
+    @property
+    def display_name(self) -> str:
+        """Returns a human-readable identifier for error messages."""
+        # model_validator ensures exactly one of name or local_path is set
+        if self.name is not None:
+            return self.name
+        return self.local_path

     def get_prompt(self, req):
         """Get the prompt sentence from the given request."""
         if self.prompt_key:
             assert self.prompt_key in req, (
-                f"Dataset {self.name} does not have key '{self.prompt_key}'. "
+                f"Dataset {self.display_name} does not have key '{self.prompt_key}'. "
                 "Please set --prompt-key to one of the available keys: "
                 f"{req.keys()}")
             return req[self.prompt_key]
@@ -76,7 +100,7 @@ class DatasetConfig(BaseModel):
     def get_input(self, req):
         """Get the input sentence from the given request."""
         assert self.input_key in req, (
-            f"Dataset {self.name} does not have key '{self.input_key}'. "
+            f"Dataset {self.display_name} does not have key '{self.input_key}'. "
             "Please set --input-key to one of the available keys: "
             f"{req.keys()}")
         return req[self.input_key]
@@ -86,7 +110,7 @@ class DatasetConfig(BaseModel):
         image_keys = [self.image_key
                       ] + [f"{self.image_key}_{i}" for i in range(1, 8)]
         assert any(key in req for key in image_keys), (
-            f"Dataset {self.name} does not have key '{self.image_key}'. "
+            f"Dataset {self.display_name} does not have key '{self.image_key}'. "
             "Please set --dataset-image-key to one of the available keys: "
             f"{req.keys()}")
         images = []
@@ -101,16 +125,47 @@ class DatasetConfig(BaseModel):
             raise RuntimeError(
                 "--output-key is not set. Please either:\n"
                 "1. Define output length through --output-len-dist.\n"
-                f"2. If the dataset {self.name} has key for golden output and "
+                f"2. If the dataset {self.display_name} has key for golden output and "
                 "you wish to set output length to the length of the golden "
                 "output, set --output-key.")
         assert self.output_key in req, (
-            f"Dataset {self.name} does not have key '{self.output_key}'. "
+            f"Dataset {self.display_name} does not have key '{self.output_key}'. "
             "Please set --output-key to one of the available keys: "
             f"{req.keys()}")
         return req[self.output_key]


+def _create_dataset_load_error(e: ValueError) -> ValueError:
+    """Create a more informative ValueError from a dataset loading error.
+
+    Args:
+        e: The original ValueError from datasets.load_dataset().
+    Returns:
+        A new ValueError with additional context.
+    """
+    error_msg = str(e)
+    if "Config" in error_msg:
+        error_msg += "\n Please add the config name to the dataset config yaml."
+    elif "split" in error_msg:
+        error_msg += "\n Please specify supported split in the dataset config yaml."
+    return ValueError(error_msg)
+
+
+def load_dataset(dataset_config: DatasetConfig):
+    """Load dataset from local path or HuggingFace.
+
+    Args:
+        dataset_config: A `DatasetConfig` object that defines the dataset to load.
+    Returns:
+        Dataset iterator.
+    Raises:
+        ValueError: When dataset loading fails due to incorrect dataset config setting.
+    """
+    if dataset_config.local_path:
+        return load_dataset_from_local(dataset_config)
+    else:
+        return load_dataset_from_hf(dataset_config)
+
+
 def load_dataset_from_hf(dataset_config: DatasetConfig):
     """Load dataset from HuggingFace.
@@ -121,55 +176,117 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
     Raises:
         ValueError: When dataset loading fails due to incorrect dataset config setting.
     """
+    logging.debug(
+        f"Loading dataset from HF: query={dataset_config.query}, split={dataset_config.split}"
+    )
+
     try:
         dataset = iter(
-            load_dataset(*dataset_config.query,
-                         split=dataset_config.split,
-                         streaming=True,
-                         trust_remote_code=True))
+            datasets.load_dataset(*dataset_config.query,
+                                  split=dataset_config.split,
+                                  streaming=True,
+                                  trust_remote_code=True))
     except ValueError as e:
-        if "Config" in e:
-            e += "\n Please add the config name to the dataset config yaml."
-        elif "split" in e:
-            e += "\n Please specify supported split in the dataset config yaml."
-        raise ValueError(e)
+        raise _create_dataset_load_error(e)
+
+    logging.debug("Finished loading HF dataset")
+
     return dataset


+def load_dataset_from_local(dataset_config: DatasetConfig):
+    """Load dataset from local path.
+
+    Args:
+        dataset_config: A `DatasetConfig` object that defines the dataset to load.
+    Returns:
+        Dataset iterator.
+    Raises:
+        FileNotFoundError: When local dataset path does not exist.
+        ValueError: When dataset loading fails due to incorrect dataset config setting.
+    """
+    local_path = Path(dataset_config.local_path)
+
+    if not local_path.exists():
+        raise FileNotFoundError(
+            f"Local dataset path {local_path} does not exist.")
+
+    logging.debug(
+        f"Loading dataset from local path: path={local_path}, query={dataset_config.query}, split={dataset_config.split}"
+    )
+
+    # If it's a directory we can use the normal loader, otherwise custom loader
+    # depends on the file extension
+    if local_path.is_dir():
+        try:
+            dataset = datasets.load_dataset(*dataset_config.query,
+                                            split=dataset_config.split,
+                                            trust_remote_code=True)
+        except ValueError as e:
+            raise _create_dataset_load_error(e)
+    else:
+        format_map = {
+            ".json": "json",
+            ".jsonl": "json",
+            ".csv": "csv",
+            ".parquet": "parquet",
+        }
+
+        file_extension = local_path.suffix
+        dataset_type = format_map.get(file_extension)
+
+        if dataset_type is None:
+            raise ValueError(f"Unsupported file extension: {file_extension}")
+
+        try:
+            dataset = datasets.load_dataset(dataset_type,
+                                            data_files=str(local_path),
+                                            split=dataset_config.split)
+        except ValueError as e:
+            raise _create_dataset_load_error(e)
+
+    logging.debug("Finished loading local dataset")
+
+    return iter(dataset)
+
+
 @click.command()
-@click.option("--dataset-name",
-              required=True,
-              type=str,
-              help=f"Dataset name in HuggingFace.")
+@click.option("--dataset-name", type=str, help="Dataset name in HuggingFace.")
 @click.option("--dataset-config-name",
               type=str,
               default=None,
-              help=f"Dataset config name in HuggingFace (if exists).")
+              help="Dataset config name in HuggingFace (if exists).")
 @click.option("--dataset-split",
               type=str,
               required=True,
-              help=f"Split of the dataset to use.")
+              help="Split of the dataset to use.")
 @click.option("--dataset-input-key",
               type=str,
-              help=f"The dataset dictionary key for input.")
+              help="The dataset dictionary key for input.")
 @click.option("--dataset-image-key",
               type=str,
               default="image",
-              help=f"The dataset dictionary key for images.")
+              help="The dataset dictionary key for images.")
 @click.option("--dataset-prompt-key",
               type=str,
               default=None,
-              help=f"The dataset dictionary key for prompt (if exists).")
+              help="The dataset dictionary key for prompt (if exists).")
+@click.option(
+    "--dataset-local-path",
+    type=str,
+    default=None,
+    help="The local path to the dataset to be loaded when using an offline cache.")
 @click.option(
     "--dataset-prompt",
     type=str,
     default=None,
-    help=f"The prompt string when there is no prompt key for the dataset.")
+    help="The prompt string when there is no prompt key for the dataset.")
 @click.option("--dataset-output-key",
               type=str,
               default=None,
-              help=f"The dataset dictionary key for output (if exists).")
+              help="The dataset dictionary key for output (if exists).")
 @click.option(
     "--num-requests",
     type=int,
@@ -208,7 +325,7 @@ def dataset(root_args, **kwargs):
     modality = None
     multimodal_texts = []
     multimodal_image_paths = []
-    for req in load_dataset_from_hf(dataset_config):
+    for req in load_dataset(dataset_config):
         if any(key in req for key in ['image', 'image_1', 'video']):
             # multimodal input
             if 'video' in req and req['video'] is not None:
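
To exercise the validators and dispatch added above, a hedged sketch; the parquet path and key values are placeholders, and required fields not visible in this diff may need values as well:

```python
import pydantic

# Exactly one of `name` / `local_path` must be set; setting both (or neither)
# trips check_name_and_local_path and surfaces as a ValidationError.
try:
    DatasetConfig(name="<org>/<dataset>",
                  local_path="/data/cache/mydataset.parquet",
                  split="train",
                  input_key="input",
                  prompt="Summarize the following text:",
                  output_key=None)
except pydantic.ValidationError as err:
    print(err)

# A local parquet file dispatches through load_dataset() to
# load_dataset_from_local(), which selects the "parquet" loader by suffix.
cfg = DatasetConfig(local_path="/data/cache/mydataset.parquet",
                    split="train",
                    input_key="input",
                    prompt="Summarize the following text:",
                    output_key=None)
first_record = next(load_dataset(cfg))
```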

@@ -1,5 +1,6 @@
-# These vulnerabilities were inherited from the base image (pytorch:25.10-py3) and should be removed when the base image
+# These vulnerabilities were inherited from the base image (pytorch:25.12-py3) and should be removed when the base image
 # is updated.
-# WAR against https://github.com/advisories/GHSA-gm62-xv2j-4w53
-# WAR against https://github.com/advisories/GHSA-2xpw-w6gg-jr37
-urllib3>=2.6.0
+# WAR against https://github.com/advisories/GHSA-38jv-5279-wg99
+urllib3>=2.6.3
+# WAR against https://github.com/advisories/GHSA-8rrh-rw8j-w5fx
+wheel>=0.46.2

@@ -68,6 +68,7 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
        ON)
 option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
        "Using open sourced Cutlass AR gemm kernel" ON)
+option(SKIP_SOFTMAX_STAT "Enable Statistics of Skip-Softmax" OFF)

 message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")

@@ -82,11 +83,6 @@ endif()
 add_compile_definitions("TLLM_GEN_EXPORT_INTERFACE")
 add_compile_definitions("TLLM_ENABLE_CUDA")

-set(BINDING_TYPE
-    "nanobind"
-    CACHE STRING
-    "Binding type of Python bindings for C++ runtime and batch manager")
-
 set(INTERNAL_CUTLASS_KERNELS_PATH
     ""
     CACHE
@@ -245,16 +241,15 @@ get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
 set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
 add_subdirectory(${3RDPARTY_DIR} 3rdparty)

-if(BINDING_TYPE STREQUAL "pybind"
-   OR BUILD_DEEP_EP
-   OR BUILD_DEEP_GEMM)
+if(BUILD_DEEP_EP
+   OR BUILD_DEEP_GEMM
+   OR BUILD_FLASH_MLA)
   FetchContent_MakeAvailable(pybind11)
   include_directories(${CMAKE_BINARY_DIR}/_deps/pybind11-src/include)
 endif()
-if(BINDING_TYPE STREQUAL "nanobind")
-  FetchContent_MakeAvailable(nanobind)
-  include_directories(${CMAKE_BINARY_DIR}/_deps/nanobind-src/include)
-endif()
+FetchContent_MakeAvailable(nanobind)
+include_directories(${CMAKE_BINARY_DIR}/_deps/nanobind-src/include)

 FetchContent_MakeAvailable(cutlass cxxopts flashmla json xgrammar)

@@ -360,6 +355,11 @@ else()
     $<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
 endif()

+if(SKIP_SOFTMAX_STAT)
+  add_compile_definitions("SKIP_SOFTMAX_STAT")
+  message(STATUS "SKIP_SOFTMAX_STAT is enabled")
+endif()
+
 # Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
 # be found in
 # https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html#index-mcmodel_003dmedium-1
@@ -369,6 +369,29 @@ if(CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64)
   set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,--no-relax")
 endif()

+# ==== BOLT Compatible Flags ====
+# These flags enable LLVM BOLT binary compatibility:
+#   -fno-reorder-blocks-and-partition: prevents the compiler from splitting
+#     hot/cold blocks
+#   -fno-plt: removes PLT stubs for additional performance
+#   -Wl,--emit-relocs (or -Wl,-q): preserves relocation info for BOLT
+# Note: binaries should NOT be stripped for BOLT to work.
+# Usage: cmake -DENABLE_BOLT_COMPATIBLE=ON ...
+option(ENABLE_BOLT_COMPATIBLE "Enable BOLT-compatible build flags" OFF)
+if(ENABLE_BOLT_COMPATIBLE AND NOT WIN32)
+  message(STATUS "BOLT compatible flags enabled")
+  # Compiler flags for C/C++
+  add_compile_options(-fno-reorder-blocks-and-partition -fno-plt)
+  # Linker flags - applies to shared, module, and executable targets
+  add_link_options(-Wl,--emit-relocs)
+  # Disable stripping - required for BOLT (affects pybind11 POST_BUILD strip)
+  set(CMAKE_STRIP
+      ""
+      CACHE STRING "Disabled for BOLT compatibility" FORCE)
+  message(STATUS "BOLT: Disabled CMAKE_STRIP to prevent binary stripping")
+endif()
+# Note: For nanobind modules, use NOSTRIP option in nanobind_add_module()
+# ==== End BOLT Compatible Flags ====
+
 # Disable deprecated declarations warnings
 if(NOT WIN32)
   set(CMAKE_CXX_FLAGS "-Wno-deprecated-declarations ${CMAKE_CXX_FLAGS}")

@@ -15,10 +15,207 @@
 # the License.
 #
+
+#[=======================================================================[.rst:
+CudaConfiguration
+-----------------
+
+CUDA compiler and architecture configuration for TensorRT-LLM.
+
+This module provides functions and macros to configure the CUDA compiler,
+manage CUDA architectures, and filter source files based on target
+architectures. It is tailored to meet TensorRT-LLM's specific requirements
+for optimized kernel compilation across multiple GPU generations.
+
+Macros
+^^^^^^
+
+.. command:: setup_cuda_compiler
+
+  Detects and validates the CUDA compiler::
+
+    setup_cuda_compiler()
+
+  This macro determines the CUDA compiler version before enabling the CUDA
+  language extension. It requires CUDA version 11.2 or later.
+
+  The macro sets ``CMAKE_CUDA_COMPILER_VERSION`` upon successful detection.
+
+Functions
+^^^^^^^^^
+
+.. command:: setup_cuda_architectures
+
+  Initializes and normalizes ``CMAKE_CUDA_ARCHITECTURES``::
+
+    setup_cuda_architectures()
+
+  This function processes the ``CMAKE_CUDA_ARCHITECTURES`` variable and
+  configures architecture-specific compilation settings. This function should
+  be called after enabling the CUDA language extension.
+
+  **Special Values for CMAKE_CUDA_ARCHITECTURES:**
+
+  ``native``
+    Resolves to the highest available architecture on the system.
+    Falls back to ``all`` if detection fails.
+
+  ``all`` or unset
+    Resolves to architectures TensorRT-LLM is optimized for and the
+    compiler supports (80, 86, 89, 90, 100, 103, 120 depending on CUDA version).
+
+  ``all-major``
+    Unsupported. Results in a fatal error.
+
+  **Architecture Processing:**
+
+  * PTX is never included in the result binary (``-virtual`` rejected).
+  * The ``-real`` suffix is automatically added to exclude PTX.
+  * Accelerated targets (``-a`` suffix) are used for SM 90+.
+  * On CUDA 12.9+, family targets (``-f`` suffix) are used for SM 100+.
+
+  **Output Variables (set in parent scope):**
+
+  ``CMAKE_CUDA_ARCHITECTURES``
+    Normalized list with appropriate suffixes (e.g., ``80-real``, ``90a-real``,
+    ``100f-real``).
+
+  ``CMAKE_CUDA_ARCHITECTURES_ORIG``
+    Original list of enabled architectures without suffixes.
+
+  ``CMAKE_CUDA_ARCHITECTURES_FAMILIES``
+    List of family architectures (e.g., ``100f``, ``120f``).
+
+  ``CMAKE_CUDA_ARCHITECTURES_HAS_FAMILIES``
+    Boolean indicating if family targets are supported.
+
+  ``CMAKE_CUDA_MIN_ARCHITECTURE_HAS_ACCEL``
+    Minimum architecture supporting the accelerated (``-a``) suffix.
+
+  ``CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY``
+    Minimum architecture supporting the family (``-f``) suffix.
+
+.. command:: add_cuda_architectures
+
+  Appends CUDA architectures to an existing target::
+
+    add_cuda_architectures(<target> <arch1> [<arch2> ...])
+
+  Adds the specified architectures to ``<target>``'s ``CUDA_ARCHITECTURES``
+  property. The ``-a`` suffix is automatically added for supported
+  architectures. Architectures are only added if they were explicitly
+  requested by the user in ``CMAKE_CUDA_ARCHITECTURES_ORIG``.
+
+.. command:: set_cuda_architectures
+
+  Sets CUDA architectures for a target::
+
+    set_cuda_architectures(<target> <arch1> [<arch2> ...])
+
+  Replaces the ``CUDA_ARCHITECTURES`` property of ``<target>`` with the
+  specified architectures.
+
+  **Architecture Specification:**
+
+  * Architectures may include the ``f`` suffix for family-conditional
+    compilation (e.g., ``100f``).
+  * Non-family architectures are only added if explicitly requested.
+  * Family architectures are only added if requested architectures would
+    enable compilation for that family.
+
+  If no architectures are enabled for the target, it compiles with the
+  ``PLACEHOLDER_KERNELS`` macro defined. The kernel source shall compile
+  with any architecture if the ``PLACEHOLDER_KERNELS`` macro is defined.
+
+.. command:: filter_source_cuda_architectures
+
+  Filters source files based on enabled CUDA architectures::
+
+    filter_source_cuda_architectures(
+      SOURCE_LIST <variable>
+      TARGET <target>
+      ARCHS <arch1> [<arch2> ...]
+      [IMPLICIT_FAMILY]
+    )
+
+  Removes source files targeting disabled CUDA architectures from the
+  source list. Files are matched by patterns like ``sm80``, ``sm_80``,
+  ``SM80``, etc. in their filenames (for ``.cu`` and ``cubin.cpp`` files).
+
+  ``SOURCE_LIST <variable>``
+    Name of the variable containing the list of source files.
+    Modified in place to remove filtered files.
+
+  ``TARGET <target>``
+    Target to add compile definitions to. If the target does not exist,
+    an INTERFACE library will be created.
+
+  ``ARCHS <arch1> [<arch2> ...]``
+    List of architectures to check. May include the ``f`` suffix.
+
+  ``IMPLICIT_FAMILY``
+    When set, treats architectures >= ``CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY``
+    as implicitly family-enabled.
+
+  **Defined Macros:**
+
+  For each filtered architecture, a compile definition ``EXCLUDE_SM_<ARCH>``
+  (or ``EXCLUDE_SM_<ARCH>F`` for family architectures) is added to ``<target>``.
+
+Example
+^^^^^^^
+
+.. code-block:: cmake
+
+  include(cuda_configuration)
+
+  # Setup compiler and detect version
+  setup_cuda_compiler()
+
+  # enable_language, or project(project_name LANGUAGES CUDA),
+  # must be called after setup_cuda_compiler() and before
+  # setup_cuda_architectures()
+  enable_language(CUDA)
+
+  # Configure architectures (uses CMAKE_CUDA_ARCHITECTURES if set)
+  setup_cuda_architectures()
+
+  # Add an additional architecture to compile for, if it is beneficial,
+  # e.g. utilizing native FP8 support available in sm89 (Ada)
+  # but not in sm86 (Ampere).
+  # Note: The kernel source must still compile for all the architectures,
+  # by using a less performant implementation.
+  add_library(my_kernels_fp8 STATIC kernels.cu)
+  add_cuda_architectures(my_kernels_fp8 89)
+
+  # Set the specific architecture this source should compile for,
+  # e.g. kernels using WGMMA instructions.
+  # Note: The kernel source must still compile for other architectures when
+  # the PLACEHOLDER_KERNELS macro is defined.
+  add_library(my_kernels_sm90_only STATIC kernels.cu)
+  set_cuda_architectures(my_kernels_sm90_only 90)
+
+  # Filter sources for disabled architectures
+  set(KERNEL_SOURCES
+      kernel_sm80.cubin.cpp
+      kernel_sm90.cubin.cpp
+      kernel_sm100.cubin.cpp
+  )
+  filter_source_cuda_architectures(
+    SOURCE_LIST KERNEL_SOURCES
+    TARGET my_kernel_interface
+    ARCHS 80 90 100
+  )
+  # The my_kernel_interface target is created with definitions to exclude
+  # disabled architectures.
+
+#]=======================================================================]
+
+#[[
+Determine CUDA version before enabling the language extension.
+check_language(CUDA) clears CMAKE_CUDA_HOST_COMPILER if CMAKE_CUDA_COMPILER
+is not set.
+#]]
 macro(setup_cuda_compiler)
-  # Determine CUDA version before enabling the language extension
-  # check_language(CUDA) clears CMAKE_CUDA_HOST_COMPILER if CMAKE_CUDA_COMPILER
-  # is not set
   include(CheckLanguage)
   if(NOT CMAKE_CUDA_COMPILER AND CMAKE_CUDA_HOST_COMPILER)
     set(CMAKE_CUDA_HOST_COMPILER_BACKUP ${CMAKE_CUDA_HOST_COMPILER})
@@ -70,25 +267,28 @@ macro(setup_cuda_compiler)
   endif()
 endmacro()

-function(setup_cuda_architectures)
-  # cmake-format: off
-  # Initialize and normalize CMAKE_CUDA_ARCHITECTURES.
-  # Special values:
-  # * `native` is resolved to HIGHEST available architecture.
-  #   * Fallback to `all` if detection failed.
-  # * `all`/unset is resolved to a set of architectures we optimized for and compiler supports.
-  # * `all-major` is unsupported.
-  # Numerical architectures:
-  # * PTX is never included in result binary.
-  #   * `*-virtual` architectures are therefore rejected.
-  # * `-real` suffix is automatically added to exclude PTX.
-  # * Always use accelerated (`-a` suffix) target for supported architectures.
-  # * On CUDA 12.9 or newer, family (`-f` suffix) target will be used for supported architectures to reduce number of
-  #   targets to compile for.
-  # * Extra architectures can be requested via add_cuda_architectures
-  #   for kernels that benefit from arch specific features.
-  # cmake-format: on
+#[[
+Initialize and normalize CMAKE_CUDA_ARCHITECTURES.
+
+Special values:
+
+* `native` is resolved to HIGHEST available architecture.
+  * Fallback to `all` if detection failed.
+* `all`/unset is resolved to a set of architectures we optimized for and compiler supports.
+* `all-major` is unsupported.
+
+Numerical architectures:
+
+* PTX is never included in result binary.
+  * `*-virtual` architectures are therefore rejected.
+* `-real` suffix is automatically added to exclude PTX.
+* Always use accelerated (`-a` suffix) target for supported architectures.
+* On CUDA 12.9 or newer, family (`-f` suffix) target will be used for supported architectures to reduce number of
+  targets to compile for.
+* Extra architectures can be requested via add_cuda_architectures
+  for kernels that benefit from arch specific features.
+#]]
+function(setup_cuda_architectures)
   set(CMAKE_CUDA_ARCHITECTURES_RAW ${CMAKE_CUDA_ARCHITECTURES})
   if(CMAKE_CUDA_ARCHITECTURES_RAW STREQUAL "native")
     # Detect highest available compute capability
@@ -138,9 +338,6 @@ function(setup_cuda_architectures)
       message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}")
     endif()
   endforeach()
-  if("103" IN_LIST CMAKE_CUDA_ARCHITECTURES_CLEAN)
-    list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN "100")
-  endif()
   list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES_CLEAN)
   set(CMAKE_CUDA_ARCHITECTURES_RAW ${CMAKE_CUDA_ARCHITECTURES_CLEAN})
 endif()
@ -182,22 +379,29 @@ function(setup_cuda_architectures)
|
|||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
# -a suffix supported from Hopper (90)
|
# -a suffix supported from Hopper (90)
|
||||||
set(MIN_ARCHITECTURE_HAS_ACCEL 90)
|
set(CMAKE_CUDA_MIN_ARCHITECTURE_HAS_ACCEL 90)
|
||||||
|
set(CMAKE_CUDA_MIN_ARCHITECTURE_HAS_ACCEL
|
||||||
|
${CMAKE_CUDA_MIN_ARCHITECTURE_HAS_ACCEL}
|
||||||
|
PARENT_SCOPE)
|
||||||
# -f suffix supported from Blackwell (100) starting from CUDA 12.9.
|
# -f suffix supported from Blackwell (100) starting from CUDA 12.9.
|
||||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.9")
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.9")
|
||||||
set(MIN_ARCHITECTURE_HAS_FAMILY 100)
|
set(CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY 100)
|
||||||
set(CMAKE_CUDA_ARCHITECTURES_HAS_FAMILIES
|
set(CMAKE_CUDA_ARCHITECTURES_HAS_FAMILIES
|
||||||
ON
|
ON
|
||||||
PARENT_SCOPE)
|
PARENT_SCOPE)
|
||||||
else()
|
else()
|
||||||
# -a provides no cross architecture compatibility, but luckily until CUDA
|
# -a provides no cross architecture compatibility, but luckily until CUDA
|
||||||
# 12.8 We have only one architecture within each family >= 9.
|
# 12.8 We have only one architecture within each family >= 9.
|
||||||
set(MIN_ARCHITECTURE_HAS_FAMILY 9999) # Effectively exclude all
|
set(CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY 9999) # Effectively exclude all
|
||||||
# architectures
|
# architectures
|
||||||
set(CMAKE_CUDA_ARCHITECTURES_HAS_FAMILIES
|
set(CMAKE_CUDA_ARCHITECTURES_HAS_FAMILIES
|
||||||
OFF
|
OFF
|
||||||
PARENT_SCOPE)
|
PARENT_SCOPE)
|
||||||
endif()
|
endif()
|
||||||
|
set(CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY
|
||||||
|
${CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY}
|
||||||
|
PARENT_SCOPE)
|
||||||
|
|
||||||
# Compatibility low bounds: Always compile kernels for these architectures. 86
|
# Compatibility low bounds: Always compile kernels for these architectures. 86
|
||||||
# is enabled to avoid perf regression when using 80 kernels.
|
# is enabled to avoid perf regression when using 80 kernels.
|
||||||
set(ARCHITECTURES_COMPATIBILITY_BASE 80 86 90 100 120)
|
set(ARCHITECTURES_COMPATIBILITY_BASE 80 86 90 100 120)
|
||||||
@ -252,11 +456,11 @@ function(setup_cuda_architectures)
|
|||||||
set(CMAKE_CUDA_ARCHITECTURES_NORMALIZED)
|
set(CMAKE_CUDA_ARCHITECTURES_NORMALIZED)
|
||||||
set(CMAKE_CUDA_ARCHITECTURES_FAMILIES)
|
set(CMAKE_CUDA_ARCHITECTURES_FAMILIES)
|
||||||
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES_NORMALIZED_LIST)
|
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES_NORMALIZED_LIST)
|
||||||
if(CUDA_ARCH GREATER_EQUAL ${MIN_ARCHITECTURE_HAS_FAMILY}
|
if(CUDA_ARCH GREATER_EQUAL ${CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY}
|
||||||
AND NOT CUDA_ARCH IN_LIST ARCHITECTURES_NO_COMPATIBILITY)
|
AND NOT CUDA_ARCH IN_LIST ARCHITECTURES_NO_COMPATIBILITY)
|
||||||
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}f-real")
|
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}f-real")
|
||||||
list(APPEND CMAKE_CUDA_ARCHITECTURES_FAMILIES "${CUDA_ARCH}f")
|
list(APPEND CMAKE_CUDA_ARCHITECTURES_FAMILIES "${CUDA_ARCH}f")
|
||||||
elseif(CUDA_ARCH GREATER_EQUAL ${MIN_ARCHITECTURE_HAS_ACCEL})
|
elseif(CUDA_ARCH GREATER_EQUAL ${CMAKE_CUDA_MIN_ARCHITECTURE_HAS_ACCEL})
|
||||||
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}a-real")
|
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}a-real")
|
||||||
else()
|
else()
|
||||||
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}-real")
|
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}-real")
|
||||||
@ -271,17 +475,15 @@ function(setup_cuda_architectures)
|
|||||||
PARENT_SCOPE)
|
PARENT_SCOPE)
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
|
#[[
|
||||||
|
Add CUDA architectures to target.
|
||||||
|
-a suffix is added automatically for supported architectures.
|
||||||
|
Architectures are added only if user explicitly requested support for that architecture.
|
||||||
|
#]]
|
||||||
function(add_cuda_architectures target)
|
function(add_cuda_architectures target)
|
||||||
# cmake-format: off
|
|
||||||
# Add CUDA architectures to target.
|
|
||||||
# -a suffix is added automatically for supported architectures.
|
|
||||||
# Architectures are added only if user explicitly requested support for that architecture.
|
|
||||||
# cmake-format: on
|
|
||||||
set(MIN_ARCHITECTURE_HAS_ACCEL 90)
|
|
||||||
|
|
||||||
foreach(CUDA_ARCH IN LISTS ARGN)
|
foreach(CUDA_ARCH IN LISTS ARGN)
|
||||||
if(${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
|
if(${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
|
||||||
if(${CUDA_ARCH} GREATER_EQUAL ${MIN_ARCHITECTURE_HAS_ACCEL})
|
if(${CUDA_ARCH} GREATER_EQUAL ${CMAKE_CUDA_MIN_ARCHITECTURE_HAS_ACCEL})
|
||||||
set(REAL_CUDA_ARCH "${CUDA_ARCH}a-real")
|
set(REAL_CUDA_ARCH "${CUDA_ARCH}a-real")
|
||||||
else()
|
else()
|
||||||
set(REAL_CUDA_ARCH "${CUDA_ARCH}-real")
|
set(REAL_CUDA_ARCH "${CUDA_ARCH}-real")
|
||||||
@ -294,18 +496,19 @@ function(add_cuda_architectures target)
|
|||||||
endforeach()
|
endforeach()
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
function(set_cuda_architectures target)
|
#[[
|
||||||
# cmake-format: off
|
Set CUDA architectures for a target.
|
||||||
# Set CUDA architectures for a target.
|
|
||||||
# -a suffix is added automatically for supported architectures.
|
|
||||||
# Architectures passed in may be specified with -f suffix to build family conditional version of the kernel.
|
|
||||||
# Non-family architectures are added only if user explicitly requested support for that architecture.
|
|
||||||
# Family conditional architectures are only added if user requested architectures would enable compilation for it.
|
|
||||||
# If user requested no architectures set on the target,
|
|
||||||
# the target will be compiled with `PLACEHOLDER_KERNELS` macro defined.
|
|
||||||
# cmake-format: on
|
|
||||||
set(MIN_ARCHITECTURE_HAS_ACCEL 90)
|
|
||||||
|
|
||||||
|
-a suffix is added automatically for supported architectures.
|
||||||
|
Architectures passed in may be specified with -f suffix to build family conditional version of the kernel.
|
||||||
|
|
||||||
|
Non-family architectures are added only if user explicitly requested support for that architecture.
|
||||||
|
Family conditional architectures are only added if user requested architectures would enable compilation for it.
|
||||||
|
|
||||||
|
If user requested no architectures set on the target,
|
||||||
|
the target will be compiled with `PLACEHOLDER_KERNELS` macro defined.
|
||||||
|
#]]
|
||||||
|
function(set_cuda_architectures target)
|
||||||
set(CUDA_ARCHITECTURES "")
|
set(CUDA_ARCHITECTURES "")
|
||||||
foreach(CUDA_ARCH IN LISTS ARGN)
|
foreach(CUDA_ARCH IN LISTS ARGN)
|
||||||
if(${CUDA_ARCH} MATCHES "[0-9]+f")
|
if(${CUDA_ARCH} MATCHES "[0-9]+f")
|
||||||
@ -326,7 +529,7 @@ function(set_cuda_architectures target)
|
|||||||
endforeach()
|
endforeach()
|
||||||
endif()
|
endif()
|
||||||
elseif(${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
|
elseif(${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
|
||||||
if(${CUDA_ARCH} GREATER_EQUAL ${MIN_ARCHITECTURE_HAS_ACCEL})
|
if(${CUDA_ARCH} GREATER_EQUAL ${CMAKE_CUDA_MIN_ARCHITECTURE_HAS_ACCEL})
|
||||||
list(APPEND CUDA_ARCHITECTURES "${CUDA_ARCH}a-real")
|
list(APPEND CUDA_ARCHITECTURES "${CUDA_ARCH}a-real")
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_ARCHITECTURES "${CUDA_ARCH}-real")
|
list(APPEND CUDA_ARCHITECTURES "${CUDA_ARCH}-real")
|
||||||
@ -342,3 +545,153 @@ function(set_cuda_architectures target)
|
|||||||
${CUDA_ARCHITECTURES})
|
${CUDA_ARCHITECTURES})
|
||||||
endif()
|
endif()
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
|
#[[
|
||||||
|
Filter out source files targeting CUDA architectures not enabled.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
SOURCE_LIST - Name of the variable containing the list of source files to filter
|
||||||
|
TARGET - Target to add compile definitions to. If the target does not exist,
|
||||||
|
an INTERFACE library will be created.
|
||||||
|
ARCHS - List of architectures to check and potentially filter
|
||||||
|
IMPLICIT_FAMILY - Optional flag to enable implicit family mode
|
||||||
|
|
||||||
|
For each ARCH passed in:
|
||||||
|
|
||||||
|
- if IMPLICIT_FAMILY is not set:
|
||||||
|
- if ARCH is not suffixed by f:
|
||||||
|
if ARCH is not in CMAKE_CUDA_ARCHITECTURES_ORIG, source files containing "sm${ARCH}"
|
||||||
|
but not "sm${ARCH}f" (case insensitive) will be excluded
|
||||||
|
Macro "EXCLUDE_SM_${ARCH}" will be defined on TARGET
|
||||||
|
- if ARCH is suffixed by f, NARCH is ARCH without f suffix:
|
||||||
|
if ARCH is not in CMAKE_CUDA_ARCHITECTURES_FAMILIES, source files containing
|
||||||
|
"sm${NARCH}f" (case insensitive) will be excluded
|
||||||
|
Macro "EXCLUDE_SM_${NARCH}F" will be defined on TARGET
|
||||||
|
|
||||||
|
- if IMPLICIT_FAMILY is set:
|
||||||
|
ARCH shall not suffixed by f.
|
||||||
|
- if ARCH >= CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY:
|
||||||
|
if "${ARCH}f" is not in CMAKE_CUDA_ARCHITECTURES_FAMILIES,
|
||||||
|
source files containing "sm${ARCH}" but not "sm${ARCH}a" (case insensitive) will be excluded
|
||||||
|
Macro "EXCLUDE_SM_${ARCH}" (no F) will be defined on TARGET
|
||||||
|
- else:
|
||||||
|
if "${ARCH}" is not in CMAKE_CUDA_ARCHITECTURES_ORIG,
|
||||||
|
source files containing "sm${ARCH}" (case insensitive) will be excluded
|
||||||
|
Macro "EXCLUDE_SM_${ARCH}" will be defined on TARGET
|
||||||
|
#]]
|
||||||
|
function(filter_source_cuda_architectures)
|
||||||
|
set(options IMPLICIT_FAMILY)
|
||||||
|
set(oneValueArgs SOURCE_LIST TARGET)
|
||||||
|
set(multiValueArgs ARCHS)
|
||||||
|
|
||||||
|
cmake_parse_arguments(PARSE_ARGV 0 arg "${options}" "${oneValueArgs}"
|
||||||
|
"${multiValueArgs}")
|
||||||
|
set(SOURCES "${${arg_SOURCE_LIST}}")
|
||||||
|
|
||||||
|
if(NOT TARGET ${arg_TARGET})
|
||||||
|
add_library(${arg_TARGET} INTERFACE)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Determine if target is INTERFACE library to use correct visibility
|
||||||
|
get_target_property(_target_type ${arg_TARGET} TYPE)
|
||||||
|
if(_target_type STREQUAL "INTERFACE_LIBRARY")
|
||||||
|
set(_compile_def_visibility INTERFACE)
|
||||||
|
else()
|
||||||
|
set(_compile_def_visibility PUBLIC)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
foreach(ARCH IN LISTS arg_ARCHS)
|
||||||
|
set(SHOULD_FILTER FALSE)
|
||||||
|
set(MATCH_PATTERN "")
|
||||||
|
set(EXCLUDE_PATTERN "")
|
||||||
|
set(ARCH_FOR_DEFINE "")
|
||||||
|
|
||||||
|
if(NOT arg_IMPLICIT_FAMILY)
|
||||||
|
# Check if ARCH ends with 'f'
|
||||||
|
string(REGEX MATCH "^(.+)f$" _has_f_suffix "${ARCH}")
|
||||||
|
|
||||||
|
if(_has_f_suffix)
|
||||||
|
# ARCH is suffixed by 'f' (e.g., "100f")
|
||||||
|
set(BASE_ARCH "${CMAKE_MATCH_1}")
|
||||||
|
if(NOT "${ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_FAMILIES)
|
||||||
|
set(SHOULD_FILTER TRUE)
|
||||||
|
set(ARCH_FOR_DEFINE "${BASE_ARCH}F")
|
||||||
|
# Match "sm${BASE_ARCH}f" - straightforward match, no exclusion
|
||||||
|
# pattern needed
|
||||||
|
set(MATCH_PATTERN ".*[Ss][Mm]_?${BASE_ARCH}f.*(cubin\.cpp|\.cu)$")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
# ARCH is NOT suffixed by 'f' (e.g., "80")
|
||||||
|
if(NOT "${ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
|
||||||
|
set(SHOULD_FILTER TRUE)
|
||||||
|
set(ARCH_FOR_DEFINE "${ARCH}")
|
||||||
|
# Match "sm${ARCH}" but NOT "sm${ARCH}f"
|
||||||
|
set(MATCH_PATTERN ".*[Ss][Mm]_?${ARCH}.*(cubin\.cpp|\.cu)$")
|
||||||
|
set(EXCLUDE_PATTERN ".*[Ss][Mm]_?${ARCH}f.*(cubin\.cpp|\.cu)$")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
# IMPLICIT_FAMILY is set - ARCH shall not be suffixed by 'f'
|
||||||
|
if(${ARCH} GREATER_EQUAL ${CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY})
|
||||||
|
# ARCH >= CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY
|
||||||
|
if(NOT "${ARCH}f" IN_LIST CMAKE_CUDA_ARCHITECTURES_FAMILIES)
|
||||||
|
set(SHOULD_FILTER TRUE)
|
||||||
|
set(ARCH_FOR_DEFINE "${ARCH}")
|
||||||
|
# Match "sm${ARCH}" but NOT "sm${ARCH}a"
|
||||||
|
set(MATCH_PATTERN ".*[Ss][Mm]_?${ARCH}.*(cubin\.cpp|\.cu)$")
|
||||||
|
set(EXCLUDE_PATTERN ".*[Ss][Mm]_?${ARCH}a.*(cubin\.cpp|\.cu)$")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
# ARCH < CMAKE_CUDA_MIN_ARCHITECTURE_HAS_FAMILY
|
||||||
|
if(NOT "${ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
|
||||||
|
set(SHOULD_FILTER TRUE)
|
||||||
|
set(ARCH_FOR_DEFINE "${ARCH}")
|
||||||
|
# Match "sm${ARCH}" - no exclusion pattern needed
|
||||||
|
set(MATCH_PATTERN ".*[Ss][Mm]_?${ARCH}.*(cubin\.cpp|\.cu)$")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(SHOULD_FILTER)
|
||||||
|
# Get files matching the main pattern
|
||||||
|
set(SOURCES_TO_CHECK "${SOURCES}")
|
||||||
|
list(FILTER SOURCES_TO_CHECK INCLUDE REGEX "${MATCH_PATTERN}")
|
||||||
|
|
||||||
|
if(NOT "${EXCLUDE_PATTERN}" STREQUAL "")
|
||||||
|
# Find files matching the exclusion pattern (these should be kept)
|
||||||
|
set(SOURCES_TO_KEEP "${SOURCES_TO_CHECK}")
|
||||||
|
list(FILTER SOURCES_TO_KEEP INCLUDE REGEX "${EXCLUDE_PATTERN}")
|
||||||
|
# Remove the files we want to keep from the check list
|
||||||
|
if(SOURCES_TO_KEEP)
|
||||||
|
list(REMOVE_ITEM SOURCES_TO_CHECK ${SOURCES_TO_KEEP})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(SOURCES_FILTERED "${SOURCES_TO_CHECK}")
|
||||||
|
|
||||||
|
list(LENGTH SOURCES_FILTERED SOURCES_FILTERED_LEN)
|
||||||
|
message(
|
||||||
|
STATUS
|
||||||
|
"Excluding ${SOURCES_FILTERED_LEN} cubins for SM ${ARCH} from ${CMAKE_CURRENT_SOURCE_DIR}"
|
||||||
|
)
|
||||||
|
foreach(filtered_item IN LISTS SOURCES_FILTERED)
|
||||||
|
message(VERBOSE "- ${filtered_item}")
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
# Remove filtered files from sources
|
||||||
|
if(SOURCES_FILTERED)
|
||||||
|
list(REMOVE_ITEM SOURCES ${SOURCES_FILTERED})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Add compile definition to target
|
||||||
|
target_compile_definitions(
|
||||||
|
${arg_TARGET}
|
||||||
|
${_compile_def_visibility}
|
||||||
|
"EXCLUDE_SM_${ARCH_FOR_DEFINE}")
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
set(${arg_SOURCE_LIST}
|
||||||
|
"${SOURCES}"
|
||||||
|
PARENT_SCOPE)
|
||||||
|
endfunction()
|
||||||
|
----------------------------------------
 cpp/conan.lock | 9 (new file)
@@ -0,0 +1,9 @@
+{
+ "version": "0.5",
+ "requires": [
+  "libnuma/system#65d9e0e45ccc1e477b97d678fa7d56bb%1769137587.8781652"
+ ],
+ "build_requires": [],
+ "python_requires": [],
+ "config_requires": []
+}
----------------------------------------
@@ -190,6 +190,14 @@ public:
     std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);
 };

+struct RequestStatuses
+{
+    /// Requests that have completed their transfer successfully.
+    std::unordered_set<LlmRequest::RequestIdType> completedRequestIds;
+    /// Requests that have encountered an error during their transfer.
+    std::unordered_set<LlmRequest::RequestIdType> errorRequestIds;
+};
+
 class BaseCacheTransceiver
 {
 public:
@@ -202,7 +210,10 @@ public:
     virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
     virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;

-    virtual void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
+    /// Check all requests transferring context, and return the requests that have completed or encountered an error.
+    virtual RequestStatuses checkContextTransferStatus(
+        std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false)
+        = 0;

     virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;

@@ -243,7 +254,8 @@ public:
     void requestAndReceiveSync(LlmRequest* llmRequest) override;
     void requestAndReceiveAsync(LlmRequest* llmRequest) override;

-    void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
+    RequestStatuses checkContextTransferStatus(
+        std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;

     void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
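For orientation, here is a minimal, self-contained sketch (not TensorRT-LLM code; MockTransceiver and its canned results are invented for illustration) of how a caller might consume the RequestStatuses value that checkContextTransferStatus now returns instead of returning void:

#include <cstdint>
#include <cstdio>
#include <unordered_set>

using RequestIdType = std::uint64_t;

struct RequestStatuses
{
    std::unordered_set<RequestIdType> completedRequestIds;
    std::unordered_set<RequestIdType> errorRequestIds;
};

struct MockTransceiver
{
    // Pretend requests 1 and 2 finished and request 3 failed.
    RequestStatuses checkContextTransferStatus(bool markComplete = false)
    {
        (void) markComplete; // in the real API this controls state transitions
        return RequestStatuses{{1, 2}, {3}};
    }
};

int main()
{
    MockTransceiver transceiver;
    auto const statuses = transceiver.checkContextTransferStatus(/*markComplete=*/true);
    for (auto id : statuses.completedRequestIds)
    {
        std::printf("request %llu: transfer complete\n", static_cast<unsigned long long>(id));
    }
    for (auto id : statuses.errorRequestIds)
    {
        std::printf("request %llu: transfer error\n", static_cast<unsigned long long>(id));
    }
    return 0;
}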
----------------------------------------
@@ -288,6 +288,9 @@ public:

     void removeNextBlock(BlockKey const& blockKey);

+    void freeDescendantsRecursively();
+    void freeBlockAndAllDescendants();
+
     //! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of
     //! blockKey.
     //! @return tuple of [partialMatch, numMatched, block], partialMatch is true if not all the tokens of the block were
@@ -365,6 +368,9 @@ private:
     std::optional<std::chrono::steady_clock::time_point::duration> mExpirationTime;
     // Hash for the event manager
     size_t mHash;
+
+    // Mutex for the next blocks
+    mutable std::mutex mNextBlocksMutex;
 };

 class GenerationRequest
@@ -380,6 +386,7 @@ public:
         , mBeamWidth(beamWidth)
         , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
         , mNumFrontBlocksRemoved(0)
+        , mCurrentPrepopulatedPromptLen(std::numeric_limits<SizeType32>::max())
     {
         auto const numWindowSizes = windowSizeToMetadata.size();
         mCacheBlockIds.reserve(numWindowSizes);
@@ -500,6 +507,20 @@ public:
         return mKvCacheRetentionConfig.getDirectory();
     }

+    [[nodiscard]] SizeType32 getCurrentPrepopulatedPromptLen() const
+    {
+        return mCurrentPrepopulatedPromptLen;
+    }
+
+    void setCurrentPrepopulatedPromptLen(SizeType32 currentPrepopulatedPromptLen)
+    {
+        TLLM_CHECK_WITH_INFO(currentPrepopulatedPromptLen <= mCurrentPrepopulatedPromptLen,
+            "currentPrepopulatedPromptLen must be updated non-increasingly due to the "
+            "assumption that smaller window sizes have shorter or equal "
+            "currentPrepopulatedPromptLen in WindowSizeManager::loadOrAllocateBlocks.");
+        mCurrentPrepopulatedPromptLen = currentPrepopulatedPromptLen;
+    }
+
 private:
     // Request id of the sequence
     LlmRequest::RequestIdType mRequestId;
@@ -517,6 +538,8 @@ private:
     SizeType32 mNumFrontBlocksRemoved;
     // Set of used blocks by the sequence
     std::set<KVCacheBlock::IdType> mUsedBlocks;
+    // Current prepopulated prompt length
+    SizeType32 mCurrentPrepopulatedPromptLen;
 };

 // attach metadata to a pool pointer
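The new mCurrentPrepopulatedPromptLen member is initialized to the maximum SizeType32 and may only decrease afterwards. A small standalone sketch of that invariant (PrepopulatedLenTracker is a simplified stand-in, not the real GenerationRequest):

#include <cassert>
#include <cstdint>
#include <limits>

using SizeType32 = std::int32_t;

// Simplified stand-in: the member starts at the maximum value so that the
// first real update always passes the non-increasing check.
class PrepopulatedLenTracker
{
public:
    SizeType32 get() const { return mLen; }

    void set(SizeType32 len)
    {
        // Mirrors the TLLM_CHECK_WITH_INFO in the diff: updates must not increase.
        assert(len <= mLen && "prepopulated prompt length must be non-increasing");
        mLen = len;
    }

private:
    SizeType32 mLen = std::numeric_limits<SizeType32>::max();
};

int main()
{
    PrepopulatedLenTracker tracker;
    tracker.set(128); // first window size matched 128 tokens
    tracker.set(96);  // a smaller window size may only match fewer (or equal) tokens
    return tracker.get() == 96 ? 0 : 1;
}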
@@ -619,7 +642,8 @@ public:
     void startScheduling();

     //! \brief Assign blocks for new sequence. Try to reuse blocks.
-    void addSequence(
+    //! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen)
+    [[nodiscard]] SizeType32 addSequence(
         GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);

     //! \brief Assign blocks for new sequence. Does not try to reuse blocks.
@@ -631,7 +655,7 @@ public:

     void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);

-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);

     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
@@ -836,8 +860,8 @@ public:
     //! \param blockKeys Key of each block.
     //! \param blockIds Id of each block.
     //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
-    //! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
-    [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
+    //! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
+    [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
         std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
         bool pinBlocks = false);

@@ -867,10 +891,15 @@ public:
         return mIsSWA;
     }

+    [[nodiscard]] bool isEnablePartialReuse() const
+    {
+        return mEnablePartialReuse;
+    }
+
     [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

-    //! \brief Unpin blocks by starting from a block id and walking prev pointers.
-    void unpinBlocksById(KVCacheBlock::IdType blockId);
+    //! \brief Unpin blocks by block ids directly
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);

     void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
     {
@@ -1001,7 +1030,7 @@ private:
     std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;

     // Mutex for the cached blocks root
-    std::mutex mCachedBlocksRootMutex;
+    mutable std::mutex mCachedBlocksRootMutex;

     // Record which sequence is using the block
     std::map<KVCacheBlock::IdType, LlmRequest::RequestIdType> mBlockToSequence;
@@ -1054,6 +1083,11 @@ public:
         return mIndexerKCacheIndexHeadDim;
     }

+    [[nodiscard]] bool isEnablePartialReuse() const
+    {
+        return mWindowBlockManagers.begin()->second.isEnablePartialReuse();
+    }
+
     BlockManager(BlockManager const&) = delete;
     BlockManager& operator=(BlockManager const&) = delete;

@@ -1068,8 +1102,9 @@ public:

     void allocatePools(bool useUvm);

-    void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
-        LlmRequest& llmRequest, SizeType32 windowSize);
+    //! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen)
+    [[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength,
+        SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize);

     //! \brief Assign blocks for a new sequence.
     //! \param sequence The GenerationRequest to process.
@@ -1086,7 +1121,7 @@ public:
     std::optional<KVCacheBlock::IdType> releaseBlocks(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@@ -1095,7 +1130,7 @@ public:
     /// @param sequence The generation request whose blocks should be pinned.
     void pinBlocks(GenerationRequest& sequence);

-    void unpinBlocksById(KVCacheBlock::IdType blockId);
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);

     void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);

@@ -1116,7 +1151,7 @@ public:
     void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
         executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

-    [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
+    [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
         std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
         SizeType32 windowSize, bool pinBlocks = false)
     {
@@ -1540,6 +1575,8 @@ public:

     [[nodiscard]] virtual bool isEnableBlockReuse() const = 0;

+    [[nodiscard]] virtual bool isEnablePartialReuse() const = 0;
+
     [[nodiscard]] virtual bool isEnableIndexerKCache() const = 0;
     [[nodiscard]] virtual SizeType32 getIndexerKCacheIndexHeadDim() const = 0;
     [[nodiscard]] virtual SizeType32 getIndexerKCacheQuantBlockSize() const = 0;
@@ -1567,7 +1604,7 @@ public:
     virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;

     /// \brief Store blocks for reuse for a given request id
-    [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
         = 0;

@@ -1661,7 +1698,15 @@ public:
         BlockKey const& blockKey, SizeType32 windowSize)
         = 0;

-    virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
+    virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
+
+    //! @brief Get the retention priority of a block by its ID.
+    //! @param blockId The ID of the block.
+    //! @param windowSize The attention window size this block belongs to.
+    //! @return The retention priority of the block, or default priority if block not found.
+    [[nodiscard]] virtual executor::RetentionPriority getPriorityByBlockId(
+        KVCacheBlock::IdType blockId, SizeType32 windowSize) const
+        = 0;
 };

 class KVCacheManager : public BaseKVCacheManager
@@ -1879,6 +1924,11 @@ public:
         return mEnableBlockReuse;
     }

+    [[nodiscard]] bool isEnablePartialReuse() const override
+    {
+        return mBlockManager.isEnablePartialReuse();
+    }
+
     [[nodiscard]] bool isEnableIndexerKCache() const override
     {
         return mBlockManager.isEnableIndexerKCache();
@@ -1922,7 +1972,7 @@ public:
     //! \brief Store newest blocks for reuse
     void storeNewBlock(LlmRequest const& llmRequest) override;

-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;

     [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@@ -1943,7 +1993,10 @@ public:

     void pinBlocks(LlmRequest::RequestIdType requestId) override;

-    void unpinBlocksById(KVCacheBlock::IdType blockId) override;
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;
+
+    [[nodiscard]] executor::RetentionPriority getPriorityByBlockId(
+        KVCacheBlock::IdType blockId, SizeType32 windowSize) const override;

     std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;
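The storeBlocksForReuse/storeBlocks signatures now return the full vector of pinned block IDs rather than an optional last-block id, and unpinBlocksById takes that vector directly. A toy sketch of the pin-on-store then unpin-by-ids flow (BlockPool is invented for illustration; the real bookkeeping is far more involved):

#include <cstdint>
#include <unordered_map>
#include <vector>

using IdType = std::int32_t;

// Toy block pool: returning every pinned id (instead of a single "last block"
// id) lets the caller later unpin exactly the blocks it pinned.
class BlockPool
{
public:
    // Store blocks for reuse; when pinBlocks is true, bump each ref count and
    // report the pinned ids back to the caller.
    std::vector<IdType> storeBlocksForReuse(std::vector<IdType> const& blockIds, bool pinBlocks)
    {
        std::vector<IdType> pinned;
        for (auto id : blockIds)
        {
            if (pinBlocks)
            {
                ++mRefCount[id];
                pinned.push_back(id);
            }
        }
        return pinned;
    }

    void unpinBlocksById(std::vector<IdType> const& blockIds)
    {
        for (auto id : blockIds)
        {
            auto it = mRefCount.find(id);
            if (it != mRefCount.end() && it->second > 0)
            {
                --it->second;
            }
        }
    }

private:
    std::unordered_map<IdType, int> mRefCount;
};

int main()
{
    BlockPool pool;
    auto pinned = pool.storeBlocksForReuse({10, 11, 12}, /*pinBlocks=*/true);
    // ... transfer or reuse happens here ...
    pool.unpinBlocksById(pinned); // release exactly the blocks we pinned
    return 0;
}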
----------------------------------------
@@ -49,6 +49,8 @@ enum class LlmRequestState : int32_t
     kUNKNOWN = 0, ///< Unknown state
     kENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models)

+    kDISAGG_CONTEXT_WAIT_SCHEDULER = 7, ///< Waiting for scheduler to schedule the context-only request
+                                        ///< e.g. in gen-first mode when generation request is not scheduled yet
     kDISAGG_GENERATION_INIT = 8, ///< New Generation request arrived at generation model
     kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< Transmitting the kv cache

@@ -65,6 +67,7 @@ enum class LlmRequestState : int32_t
     kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 21, ///< Waiting context-only request transmitting the kv cache,
                                             ///< after computation finished
     kDISAGG_CONTEXT_COMPLETE = 22, ///< Context-only request finished kv cache transmission.
+    kDISAGG_GENERATION_WAIT_TOKENS = 23, ///< Generation-only request waiting for ctx/draft tokens to be received

     // error states
     kDISAGG_TRANS_ERROR = -1, ///< Error occurred during kv cache transmission
@@ -116,6 +119,7 @@ public:
         std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> multimodalHashes = std::nullopt,
         std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalPositions = std::nullopt,
         std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalLengths = std::nullopt,
+        std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> multimodalUuids = std::nullopt,
         std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
         std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
         std::optional<SizeType32> mropePositionDeltas = std::nullopt,
@@ -165,6 +169,7 @@ public:
         , mMultimodalHashes(std::move(multimodalHashes))
         , mMultimodalPositions(std::move(multimodalPositions))
         , mMultimodalLengths(std::move(multimodalLengths))
+        , mMultimodalUuids(std::move(multimodalUuids))
         , mMultimodalEmbedding(std::move(multimodalEmbedding))
         , mMropeRotaryCosSin(std::move(mropeRotaryCosSin))
         , mMropePositionDeltas(mropePositionDeltas)
@@ -838,6 +843,20 @@ public:
         // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
         mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
                                                                      : LlmRequestState::kCONTEXT_INIT;
+
+        if (mLlmRequestType == LlmRequestType::LLMREQUEST_TYPE_GENERATION_ONLY)
+        {
+
+            // If gen only server is configured with MAX_UTILIZATION scheduler, the running gen only request may be
+            // paused and rescheduled as context_init state, which will run context phase, degrading performance.
+            // There is no clean way to avoid this: if we modify the max utilization scheduler to avoid pausing
+            // generation-only requests, it could result in no KV cache being available, causing requests to remain
+            // unscheduled indefinitely. We just issue a warning here.
+            TLLM_LOG_WARNING(
+                "Pausing generation-only request, request_id: %lu, changes it to context init state, which may degrade "
+                "performance.",
+                mRequestId);
+        }
         mContextCurrentPositionTarget = 0;
         mContextCurrentPositionDraft = 0;
         mPrepopulatedPromptLenTarget = 0;
@@ -892,6 +911,11 @@ public:
         return mMultimodalLengths;
     }

+    [[nodiscard]] std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> getMultimodalUuids() const
+    {
+        return mMultimodalUuids;
+    }
+
     [[nodiscard]] std::optional<TensorPtr> getMultimodalEmbedding() const
     {
         return mMultimodalEmbedding;
@@ -1511,15 +1535,17 @@ public:
     {
         switch (mState)
         {
-        case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS; break;
-        case batch_manager::LlmRequestState::kCONTEXT_INIT: return executor::RequestStage::kCONTEXT_IN_PROGRESS; break;
+        case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS;
+        case batch_manager::LlmRequestState::kCONTEXT_INIT:
+        case batch_manager::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER:
+            return executor::RequestStage::kCONTEXT_IN_PROGRESS;
         case batch_manager::LlmRequestState::kGENERATION_IN_PROGRESS:
         case batch_manager::LlmRequestState::kGENERATION_TO_COMPLETE:
         case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE:
         case batch_manager::LlmRequestState::kDISAGG_GENERATION_INIT:
         case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS:
+        case batch_manager::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS:
             return executor::RequestStage::kGENERATION_IN_PROGRESS;
-            break;
         default: TLLM_LOG_ERROR("Unexpected request state."); return executor::RequestStage::kGENERATION_COMPLETE;
         }
     }
@@ -1536,8 +1562,14 @@ public:

     void setContextCurrentPosition(SizeType32 contextCurrentPosition)
     {
-        mContextCurrentPositionDraft = contextCurrentPosition;
-        mContextCurrentPositionTarget = contextCurrentPosition;
+        if (mUseDraftModel)
+        {
+            mContextCurrentPositionDraft = contextCurrentPosition;
+        }
+        else
+        {
+            mContextCurrentPositionTarget = contextCurrentPosition;
+        }
     }

     /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
@@ -1667,6 +1699,12 @@ public:
             [](auto reason) { return reason == executor::FinishReason::kLENGTH; });
     }

+    [[nodiscard]] bool isFinishedDueToCancellation() const noexcept
+    {
+        return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
+            [](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
+    }
+
     [[nodiscard]] bool isTimedOut() const
     {
         if (!mAllottedTimeMs.has_value())
@@ -1933,6 +1971,7 @@ protected:
     std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> mMultimodalHashes{std::nullopt};
     std::optional<std::shared_ptr<std::vector<SizeType32>>> mMultimodalPositions{std::nullopt};
     std::optional<std::shared_ptr<std::vector<SizeType32>>> mMultimodalLengths{std::nullopt};
+    std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>> mMultimodalUuids{std::nullopt};
     std::optional<TensorPtr> mMultimodalEmbedding{std::nullopt};
     std::optional<TensorPtr> mMropeRotaryCosSin{std::nullopt};
     std::optional<SizeType32> mMropePositionDeltas{std::nullopt};
@@ -2221,6 +2260,7 @@ public:
         std::optional<std::vector<std::vector<SizeType32>>> multimodalHashes = std::nullopt,
         std::optional<std::vector<SizeType32>> multimodalPositions = std::nullopt,
         std::optional<std::vector<SizeType32>> multimodalLengths = std::nullopt,
+        std::optional<std::vector<std::optional<std::string>>> multimodalUuids = std::nullopt,
         std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
         std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
         std::optional<SizeType32> mropePositionDeltas = std::nullopt,
@@ -2261,6 +2301,9 @@ public:
         multimodalLengths.has_value()
             ? std::make_shared<std::vector<SizeType32>>(std::move(multimodalLengths.value()))
             : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
+        multimodalUuids.has_value()
+            ? std::make_shared<std::vector<std::optional<std::string>>>(std::move(multimodalUuids.value()))
+            : std::optional<std::shared_ptr<std::vector<std::optional<std::string>>>>(std::nullopt),
         std::move(multimodalEmbedding), std::move(mropeRotaryCosSin), mropePositionDeltas, loraTaskId,
         std::move(loraWeights), std::move(loraConfig), lookaheadConfig, std::move(kvCacheRetentionConfig),
         returnLogProbs, returnContextLogits, returnGenerationLogits,
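isFinishedDueToCancellation follows the same std::all_of pattern as the existing length-based check: every beam must report the finish reason. A standalone sketch with a simplified FinishReason enum (not the real executor types):

#include <algorithm>
#include <vector>

enum class FinishReason
{
    kNOT_FINISHED,
    kEND_ID,
    kLENGTH,
    kCANCELLED
};

// Per-beam finish reasons; the request counts as cancelled only when every
// beam reports kCANCELLED, mirroring isFinishedDueToCancellation in the diff.
bool isFinishedDueToCancellation(std::vector<FinishReason> const& reasons)
{
    return std::all_of(
        reasons.begin(), reasons.end(), [](auto reason) { return reason == FinishReason::kCANCELLED; });
}

int main()
{
    std::vector<FinishReason> allCancelled{FinishReason::kCANCELLED, FinishReason::kCANCELLED};
    std::vector<FinishReason> mixed{FinishReason::kCANCELLED, FinishReason::kLENGTH};
    return (isFinishedDueToCancellation(allCancelled) && !isFinishedDueToCancellation(mixed)) ? 0 : 1;
}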
----------------------------------------
@@ -57,7 +57,10 @@ class BasePeftCacheManager
 public:
     using LlmRequestPtr = std::shared_ptr<LlmRequest>;
     using RequestVector = std::vector<LlmRequestPtr>;
-    using PeftTable = std::map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
+    using PeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
+    using TaskPeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
+    using TaskIdToReqIds = std::unordered_map<uint64_t, std::vector<uint64_t>>;
+    using EnsureBatchTaskResult = std::tuple<TaskPeftTable, TaskIdToReqIds>;

     virtual ~BasePeftCacheManager() = default;

@@ -99,6 +102,8 @@ public:
 class PeftCacheManager : public BasePeftCacheManager
 {
 public:
+    using EnsureBatchTaskResult = BasePeftCacheManager::EnsureBatchTaskResult;
+
     PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
         runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);

@@ -109,12 +114,17 @@ public:
     PeftTable ensureBatch(RequestVector const& contextRequests, RequestVector const& generationRequests,
         bool resetGpuCache = false) override;

+    EnsureBatchTaskResult ensureBatchMapTaskId(
+        RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache = false);
+
     [[nodiscard]] bool isTaskCached(uint64_t taskId) const;

     [[nodiscard]] bool isTaskDone(uint64_t taskId) const;

     [[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const;

+    [[nodiscard]] bool isTaskCachedDevice(uint64_t const taskId) const;
+
     void resetDeviceCache() override;

     void markRequestDone(LlmRequest const& llmReq, bool pause = false) override;
@@ -159,7 +169,7 @@ private:
     std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToReqIds;
     std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToPausedReqIds;

-    std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>> getTaskMaps(
+    std::tuple<std::unordered_map<uint64_t, std::future<void>>, TaskIdToReqIds> getTaskMaps(
         RequestVector const& contextRequests, RequestVector const& generationRequests);

     runtime::ModelConfig mModelConfig;
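ensureBatchMapTaskId returns an EnsureBatchTaskResult, a tuple of the task-keyed PEFT table and the task-to-request-ids map. A minimal sketch of unpacking it with structured bindings (TaskLayerModuleConfig is reduced to an empty struct here; the data is invented for illustration):

#include <cstdint>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

// Stand-in for the real LoRA module config; only the shape of the result matters here.
struct TaskLayerModuleConfig { };

using TaskPeftTable = std::unordered_map<std::uint64_t, std::vector<TaskLayerModuleConfig>>;
using TaskIdToReqIds = std::unordered_map<std::uint64_t, std::vector<std::uint64_t>>;
using EnsureBatchTaskResult = std::tuple<TaskPeftTable, TaskIdToReqIds>;

int main()
{
    // A hypothetical result as ensureBatchMapTaskId might produce: task 7 maps
    // to two requests that share the same LoRA weights.
    TaskPeftTable peft;
    peft[7u] = {TaskLayerModuleConfig{}};
    TaskIdToReqIds reqs;
    reqs[7u] = {100u, 101u};
    EnsureBatchTaskResult result{std::move(peft), std::move(reqs)};

    // Structured bindings make the task-keyed split explicit at the call site.
    auto const& [taskPeftTable, taskIdToReqIds] = result;

    std::size_t pairs = 0;
    for (auto const& [taskId, reqIds] : taskIdToReqIds)
    {
        pairs += reqIds.size() * taskPeftTable.at(taskId).size();
    }
    return pairs == 2 ? 0 : 1;
}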
----------------------------------------
@@ -16,11 +16,16 @@

 #pragma once

+#include "tensorrt_llm/batch_manager/common.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/modelConfig.h"
 #include "tensorrt_llm/runtime/worldConfig.h"

+#include <optional>
+#include <unordered_map>
+#include <vector>
+
 namespace tensorrt_llm::batch_manager::rnn_state_manager
 {

@@ -30,16 +35,34 @@ public:
     using TensorPtr = runtime::ITensor::SharedPtr;
     using SizeType32 = tensorrt_llm::runtime::SizeType32;
     using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
+    using RequestIdType = tensorrt_llm::batch_manager::RequestIdType;

     RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runtime::ModelConfig const& modelConfig,
         runtime::WorldConfig const& worldConfig, tensorrt_llm::runtime::BufferManager const& bufferManager);

+    RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups, SizeType32 headDim,
+        SizeType32 maxBatchSize, runtime::WorldConfig const& worldConfig, int64_t stream, nvinfer1::DataType dtype,
+        nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers);
+
     void getPtrBuffers(TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig,
         runtime::WorldConfig const& worldConfig) const;

     void fillSlotMapping(
         runtime::ITensor& dstPointers, SizeType32 dstSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const;

+    void allocateCacheBlocks(std::vector<RequestIdType> const& requestIds);
+
+    void freeCacheBlock(RequestIdType requestId);
+
+    [[nodiscard]] SizeType32 getCacheIndex(RequestIdType requestId) const;
+
+    [[nodiscard]] std::vector<SizeType32> getStateIndices(
+        std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding);
+
+    [[nodiscard]] TensorPtr getConvStates(SizeType32 layerIdx) const;
+
+    [[nodiscard]] TensorPtr getSsmStates(SizeType32 layerIdx) const;
+
 private:
     // If we need support beam search, we may need mMaxBeamWidth + 1 slots and use separate input / output states.
     TensorPtr pagedRnnStates; // [local_nb_layers, max_seq_num * max_beam_width, state_size, rnn_hidden_size] or
@@ -55,6 +78,10 @@ private:
     SizeType32 mMaxNumSequences = 0;
     SizeType32 mMaxBeamWidth = 0;
     SizeType32 mBeamSlotsPerSequence = 0;
+    std::unordered_map<SizeType32, SizeType32> mLayerOffsets;
+    std::vector<SizeType32> mFreeBlocks;
+    std::unordered_map<RequestIdType, SizeType32> mCacheIndex;
+    std::optional<runtime::BufferManager> mBufferManager;
 };

 } // namespace tensorrt_llm::batch_manager::rnn_state_manager
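The new allocateCacheBlocks/freeCacheBlock/getCacheIndex surface, together with the mFreeBlocks and mCacheIndex members, suggests a free list of state slots keyed by request id. A toy allocator with that shape (SlotAllocator is illustrative only; the real RnnStateManager owns GPU state tensors behind these indices):

#include <cstdint>
#include <stdexcept>
#include <unordered_map>
#include <vector>

using RequestIdType = std::uint64_t;
using SizeType32 = std::int32_t;

// Each request id is bound to one state slot drawn from a free list, and the
// slot returns to the list when freed.
class SlotAllocator
{
public:
    explicit SlotAllocator(SizeType32 numSlots)
    {
        for (SizeType32 i = numSlots - 1; i >= 0; --i)
        {
            mFreeBlocks.push_back(i);
        }
    }

    void allocateCacheBlocks(std::vector<RequestIdType> const& requestIds)
    {
        for (auto id : requestIds)
        {
            if (mCacheIndex.count(id) != 0)
            {
                continue; // already has a slot
            }
            if (mFreeBlocks.empty())
            {
                throw std::runtime_error("no free state slots");
            }
            mCacheIndex[id] = mFreeBlocks.back();
            mFreeBlocks.pop_back();
        }
    }

    void freeCacheBlock(RequestIdType requestId)
    {
        auto it = mCacheIndex.find(requestId);
        if (it != mCacheIndex.end())
        {
            mFreeBlocks.push_back(it->second);
            mCacheIndex.erase(it);
        }
    }

    SizeType32 getCacheIndex(RequestIdType requestId) const
    {
        return mCacheIndex.at(requestId);
    }

private:
    std::vector<SizeType32> mFreeBlocks;
    std::unordered_map<RequestIdType, SizeType32> mCacheIndex;
};

int main()
{
    SlotAllocator alloc(4);
    alloc.allocateCacheBlocks({42, 43});
    auto idx = alloc.getCacheIndex(42);
    alloc.freeCacheBlock(42); // slot returns to the free list
    return idx >= 0 && idx < 4 ? 0 : 1;
}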
----------------------------------------
@@ -151,26 +151,6 @@ void checkEx(
 #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
 #define check_cuda_error_2(val, file, line) check((val), #val, file, line)

-inline std::optional<bool> isCudaLaunchBlocking()
-{
-    thread_local bool firstCall = true;
-    thread_local std::optional<bool> result = std::nullopt;
-    if (!firstCall)
-    {
-        char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
-        if (env != nullptr && std::string(env) == "1")
-        {
-            result = true;
-        }
-        else
-        {
-            result = false;
-        }
-        firstCall = false;
-    }
-    return result;
-}
-
 inline bool isCapturing(cudaStream_t stream)
 {
     cudaStreamCaptureStatus status;
@@ -180,21 +160,23 @@ inline bool isCapturing(cudaStream_t stream)

 inline bool doCheckError(cudaStream_t stream)
 {
-    auto const cudaLaunchBlocking = isCudaLaunchBlocking();
-    if (cudaLaunchBlocking.has_value() && cudaLaunchBlocking.value())
+    // If we're capturing a CUDA graph we don't check. Otherwise, we
+    // default to only checking in debug builds. But we always listen to
+    // the env variable.
+    static bool const doCheckIfNotCapturing = []()
     {
-        return !isCapturing(stream);
-    }
+        char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
+        if (env != nullptr)
+        {
+            return std::string(env) == "1";
+        }
 #ifndef NDEBUG
-    // Debug builds will sync when we're not capturing unless explicitly
-    // disabled.
-    bool const checkError = cudaLaunchBlocking.value_or(!isCapturing(stream));
+        return true;
 #else
-    bool const checkError = cudaLaunchBlocking.value_or(false);
+        return false;
 #endif
-    return checkError;
+    }();
+    return doCheckIfNotCapturing && !isCapturing(stream);
 }

 inline void syncAndCheck(cudaStream_t stream, char const* const file, int const line)
|||||||
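The rewritten doCheckError above relies on C++ function-local statics ("magic statics"): the immediately-invoked lambda runs exactly once, thread-safely, so CUDA_LAUNCH_BLOCKING is read a single time per process rather than cached per thread as before. A minimal standalone sketch of that pattern (function name is illustrative, not from the diff):

    #include <cstdlib>
    #include <string>

    // Reads CUDA_LAUNCH_BLOCKING exactly once; later calls return the cached value.
    inline bool launchBlockingRequested()
    {
        static bool const value = []()
        {
            char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
            return env != nullptr && std::string(env) == "1";
        }();
        return value;
    }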
@@ -17,6 +17,7 @@
 #pragma once

 #include "tensorrt_llm/executor/serialization.h"
+#include <atomic>
 #include <vector>

 namespace tensorrt_llm::executor::kv_cache
@@ -27,8 +28,9 @@ class CommState;
 struct DataContext
 {
 public:
-    explicit DataContext(int tag)
+    explicit DataContext(int tag, std::atomic<bool> const& transferTerminate = sDefaultTransferTerminate)
         : mTag{tag}
+        , mTransferTerminate(transferTerminate)
     {
     }
@@ -37,8 +39,15 @@ public:
         return mTag;
     }

+    [[nodiscard]] std::atomic<bool> const& getTransferTerminate() const noexcept
+    {
+        return mTransferTerminate;
+    }
+
 private:
+    inline static std::atomic<bool> sDefaultTransferTerminate{false};
     int const mTag;
+    std::atomic<bool> const& mTransferTerminate;
 };

 class Connection
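The new transferTerminate reference lets a long-running transfer observe a caller-owned cancellation flag without taking ownership of it, and the inline static default keeps the old one-argument construction working. A hedged sketch of how a transfer loop might poll such a flag (hasMoreChunks and sendNextChunk are hypothetical stand-ins for real transfer logic):

    #include <atomic>

    bool hasMoreChunks();  // hypothetical transfer-progress query
    void sendNextChunk();  // hypothetical chunk sender

    void runTransfer(std::atomic<bool> const& terminate)
    {
        while (hasMoreChunks())
        {
            // A relaxed load suffices for a latched stop flag.
            if (terminate.load(std::memory_order_relaxed))
            {
                return; // abort cooperatively between chunks
            }
            sendNextChunk();
        }
    }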
@@ -51,7 +51,8 @@ public:
     CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableBlockReuse = false,
-        bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
+        bool enablePartialReuse = false, bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0,
+        SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig(std::move(modelConfig))
         , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
             worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
@@ -60,6 +61,7 @@ public:
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mEnablePartialReuse = enablePartialReuse;
         mHasIndexerKCache = hasIndexerKCache;
         mIndexerDimPerHead = indexerDimPerHead;
         mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
@@ -69,8 +71,8 @@ public:
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool hasIndexerKCache = false,
-        SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
+        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool enablePartialReuse = false,
+        bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
         , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
             attentionLayerNumPerPP}
@@ -78,6 +80,7 @@ public:
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mEnablePartialReuse = enablePartialReuse;
         mHasIndexerKCache = hasIndexerKCache;
         mIndexerDimPerHead = indexerDimPerHead;
         mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
@@ -87,8 +90,8 @@ public:
     SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool hasIndexerKCache = false,
-        SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
+        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool enablePartialReuse = false,
+        bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
         , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
             attentionLayerNumPerPP}
@@ -96,6 +99,7 @@ public:
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mEnablePartialReuse = enablePartialReuse;
         mHasIndexerKCache = hasIndexerKCache;
         mIndexerDimPerHead = indexerDimPerHead;
         mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
@@ -186,6 +190,11 @@ public:
         return mEnableBlockReuse;
     }

+    [[nodiscard]] bool getEnablePartialReuse() const
+    {
+        return mEnablePartialReuse;
+    }
+
     [[nodiscard]] bool getHasIndexerKCache() const
     {
         return mHasIndexerKCache;
@@ -221,6 +230,7 @@ public:
         sstring << "dpRank:" << mParallelConfig.mDPrank << "\n";
         sstring << "dpSize:" << mParallelConfig.mDPsize << "\n";
         sstring << "enableBlockReuse:" << mEnableBlockReuse << "\n";
+        sstring << "enablePartialReuse:" << mEnablePartialReuse << "\n";
         sstring << "hasIndexerKCache:" << mHasIndexerKCache << "\n";
         sstring << "indexerDimPerHead:" << mIndexerDimPerHead << "\n";
         sstring << "indexerKCacheQuantBlockSize:" << mIndexerKCacheQuantBlockSize << "\n";
@@ -234,6 +244,7 @@ private:
     nvinfer1::DataType mDataType;
     AttentionConfig mAttentionConfig;
     bool mEnableBlockReuse{false};
+    bool mEnablePartialReuse{false};
     bool mHasIndexerKCache{false};
     SizeType32 mIndexerDimPerHead{0};
     SizeType32 mIndexerKCacheQuantBlockSize{128};
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -48,10 +48,6 @@ namespace tensorrt_llm::executor
 {

 using SizeType32 = tensorrt_llm::runtime::SizeType32;
-// Mmkey is used in KVCacheBlock when multimodal data presents in a block.
-// Type alias for hash array + start offset at per-block granularity.
-// This differs from the per-request level multimodal hash in MultimodalInput.
-using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;

 /// @brief Version of TRT-LLM
 char const* version() noexcept;
@@ -301,11 +297,13 @@ class MultimodalInput
 {
 public:
     explicit MultimodalInput(std::vector<std::vector<SizeType32>> multimodalHashes,
-        std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths);
+        std::vector<SizeType32> multimodalPositions, std::vector<SizeType32> multimodalLengths,
+        std::optional<std::vector<std::optional<std::string>>> multimodalUuids = std::nullopt);

     [[nodiscard]] std::vector<std::vector<SizeType32>> getMultimodalHashes() const;
     [[nodiscard]] std::vector<SizeType32> getMultimodalPositions() const;
     [[nodiscard]] std::vector<SizeType32> getMultimodalLengths() const;
+    [[nodiscard]] std::optional<std::vector<std::optional<std::string>>> const& getMultimodalUuids() const;

 private:
     friend class Serialization;
@@ -315,6 +313,9 @@ private:
     std::vector<SizeType32> mMultimodalPositions;
     /// @brief The multimodal lengths
     std::vector<SizeType32> mMultimodalLengths;
+    /// @brief Optional user-provided UUIDs for multimodal items.
+    /// When provided, these are returned in KV cache events instead of content hashes.
+    std::optional<std::vector<std::optional<std::string>>> mMultimodalUuids;
 };

 /// @brief Configuration for mrope
@@ -442,11 +443,15 @@ class ContextPhaseParams
 public:
     using RequestIdType = std::uint64_t;

-    ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens);
-    ContextPhaseParams(
-        VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens);
+    ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens,
+        std::optional<SizeType32> ctxDpRank = std::nullopt,
+        std::optional<std::string> disaggInfoEndpoint = std::nullopt);
+    ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens,
+        std::optional<SizeType32> ctxDpRank = std::nullopt,
+        std::optional<std::string> disaggInfoEndpoint = std::nullopt);
     ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::vector<char> const& serializedState,
-        std::optional<VecTokens> draftTokens);
+        std::optional<VecTokens> draftTokens, std::optional<SizeType32> ctxDpRank = std::nullopt,
+        std::optional<std::string> disaggInfoEndpoint = std::nullopt);

     ContextPhaseParams(ContextPhaseParams const&);
     ContextPhaseParams(ContextPhaseParams&&) noexcept;
@@ -457,15 +462,22 @@ public:
     [[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;

     [[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
+    void setFirstGenTokens(VecTokens const& firstGenTokens) noexcept;
     [[nodiscard]] std::optional<VecTokens> const& getDraftTokens() const& noexcept;
+    void setDraftTokens(std::optional<VecTokens> const& draftTokens) noexcept;
     [[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
     [[nodiscard]] RequestIdType getReqId() const noexcept;
+    void setReqId(RequestIdType const& reqId) noexcept;
     [[nodiscard]] void const* getState() const noexcept;
     [[nodiscard]] void* getState() noexcept;
     [[nodiscard]] void* releaseState() noexcept;
     [[nodiscard]] std::vector<char> getSerializedState() const noexcept;

+    [[nodiscard]] std::optional<SizeType32> getCtxDpRank() const noexcept;
+    void setCtxDpRank(std::optional<SizeType32> const& ctxDpRank) noexcept;
+    [[nodiscard]] std::optional<std::string> const& getDisaggInfoEndpoint() const noexcept;
+    void setDisaggInfoEndpoint(std::optional<std::string> const& disaggInfoEndpoint) noexcept;
+
 private:
     friend class Serialization;
     static void deleter(void const* data);
@@ -482,6 +494,12 @@ private:

     /// @brief The draft tokens generated by context executor
     std::optional<VecTokens> mDraftTokens;
+
+    /// @brief The context phase data parallel rank
+    std::optional<SizeType32> mCtxDpRank;
+
+    /// @brief The disaggregated info endpoint
+    std::optional<std::string> mDisaggInfoEndpoint;
 };

 /// @brief Configuration for speculative decoding (both draft and target models)
@@ -684,6 +702,7 @@ public:
     /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
     /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
     /// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string.
+    /// @param disaggRequestId Disaggregated request ID.
     Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
         SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
         std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -711,7 +730,8 @@ public:
         std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
+        std::optional<IdType> disaggRequestId = std::nullopt);

     /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
     static auto constexpr kBatchedPostProcessorName = "batched";
@@ -761,6 +781,7 @@ public:
     [[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
     [[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
     [[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
+    [[nodiscard]] std::optional<IdType> getDisaggRequestId() const;

     void setStreaming(bool streaming);
     void setSamplingConfig(SamplingConfig const& config);
@@ -796,6 +817,7 @@ public:
     void setLanguageAdapterUid(SizeType32 languageAdapterUid);
     void setAllottedTimeMs(MillisecondsType allottedTimeMs);
     void setCacheSaltID(CacheSaltIDType cacheSaltID);
+    void setDisaggRequestId(IdType disaggRequestId);

 private:
     friend class Serialization;
@@ -1468,7 +1490,8 @@ public:
         DEFAULT = 0,
         MPI = 1,
         UCX = 2,
-        NIXL = 3
+        NIXL = 3,
+        MOONCAKE = 4
     };
     explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
         std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,
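A hedged usage sketch of the extended ContextPhaseParams API (token, rank, and endpoint values are made up; VecTokens, SizeType32, and the declarations above come from this header):

    using namespace tensorrt_llm::executor;

    VecTokens firstTokens{101, 102};
    ContextPhaseParams params(firstTokens, /*reqId=*/42, /*draftTokens=*/std::nullopt,
        /*ctxDpRank=*/0, /*disaggInfoEndpoint=*/std::string("tcp://ctx-worker-0:9300"));

    params.setCtxDpRank(1); // both new fields are also settable after construction
    if (auto const& endpoint = params.getDisaggInfoEndpoint())
    {
        // route the generation-phase request to *endpoint
    }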
@@ -349,6 +349,11 @@ public:
     static void serialize(KVCacheUpdatedData const& data, std::ostream& os);
     [[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is);

+    // MmKey
+    [[nodiscard]] static size_t serializedSize(MmKey const& key);
+    static void serialize(MmKey const& key, std::ostream& os);
+    [[nodiscard]] static MmKey deserializeMmKey(std::istream& is);
+
     // UniqueToken
     [[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token);
     static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os);
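The new MmKey entries follow the Serialization class's fixed triple per type: a size query, a writer, and a named reader. A minimal sketch of that shape for a hypothetical trivially-copyable Point type (not part of the diff; MmKey itself would additionally need to encode the optional uuid, e.g. a presence byte plus length-prefixed bytes):

    #include <cstddef>
    #include <istream>
    #include <ostream>

    struct Point
    {
        int x;
        int y;
    };

    [[nodiscard]] size_t serializedSize(Point const&)
    {
        return 2 * sizeof(int);
    }

    void serialize(Point const& p, std::ostream& os)
    {
        os.write(reinterpret_cast<char const*>(&p.x), sizeof(p.x));
        os.write(reinterpret_cast<char const*>(&p.y), sizeof(p.y));
    }

    [[nodiscard]] Point deserializePoint(std::istream& is)
    {
        Point p{};
        is.read(reinterpret_cast<char*>(&p.x), sizeof(p.x));
        is.read(reinterpret_cast<char*>(&p.y), sizeof(p.y));
        return p;
    }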
@@ -274,13 +274,20 @@ private:
     std::optional<SyncMessage> mSyncMessage;
 };

+enum class TransferState : uint8_t
+{
+    kIN_PROGRESS,
+    kSUCCESS,
+    kFAILURE,
+};
+
 // Data structure for checking the status of active transfer operations.
 class TransferStatus
 {
 public:
     virtual ~TransferStatus() = default;
     [[nodiscard]] virtual bool isCompleted() const = 0;
-    virtual void wait() const = 0;
+    virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
 };

 struct BaseAgentConfig
@@ -288,6 +295,9 @@ struct BaseAgentConfig
     std::string mName;
     bool useProgThread;
     bool multiThread;
+    bool useListenThread;
+    bool enableTelemetry;
+    std::unordered_map<std::string, std::string> backendParams;
 };

 class BaseTransferAgent
@@ -391,6 +401,14 @@ template <typename... Args>
             "libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
         return func(std::forward<Args>(args)...);
     }
+    if (backend == "mooncake")
+    {
+        auto& loader = DynLibLoader::getInstance();
+        using CreateMooncakeFuncType = std::unique_ptr<BaseTransferAgent> (*)(BaseAgentConfig const*);
+        auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
+            "libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
+        return func(std::forward<Args>(args)...);
+    }
     TLLM_THROW("Unknown backend name.");
 }

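With the timeout overload, callers can distinguish "still running" from the two terminal outcomes instead of blocking forever. A hedged polling sketch against the new contract, assuming the declarations above are in scope:

    void pollOnce(TransferStatus const& status)
    {
        // -1 preserves the old blocking behaviour; a positive value bounds the wait in ms.
        switch (status.wait(/*timeout_ms=*/1000))
        {
        case TransferState::kIN_PROGRESS:
            // timed out; keep waiting, or request cooperative termination
            break;
        case TransferState::kSUCCESS:
            break;
        case TransferState::kFAILURE:
            // surface the error to the caller
            break;
        }
    }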
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@

 #pragma once

+#include <array>
 #include <chrono>
 #include <cstdint>
 #include <functional>
@@ -70,6 +71,29 @@ using EagleChoices = std::vector<std::vector<SizeType32>>;
 using PriorityType = float;
 using BufferView = std::basic_string_view<uint8_t>;

+//! MmKey is used in KVCacheBlock when multimodal data presents in a block.
+//! Hash is a 32-byte array; startOffset is the per-block token offset; uuid is optional.
+struct MmKey
+{
+    std::array<uint8_t, 32> hash;
+    SizeType32 startOffset{};
+    std::optional<std::string> uuid{std::nullopt};
+
+    MmKey() = default;
+
+    MmKey(std::array<uint8_t, 32> hash, SizeType32 startOffset, std::optional<std::string> uuid = std::nullopt)
+        : hash(std::move(hash))
+        , startOffset(startOffset)
+        , uuid(std::move(uuid))
+    {
+    }
+
+    bool operator==(MmKey const& other) const noexcept
+    {
+        return hash == other.hash && startOffset == other.startOffset && uuid == other.uuid;
+    }
+};
+
 enum class DataType
 {
     kBOOL,
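The new MmKey struct ships an operator== but no std::hash specialization is visible in this hunk; code keying an unordered container on MmKey would have to supply one. A hedged sketch (assuming MmKey lives in tensorrt_llm::executor like the surrounding aliases; the FNV-1a mixing scheme is illustrative only):

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <string>

    template <>
    struct std::hash<tensorrt_llm::executor::MmKey>
    {
        std::size_t operator()(tensorrt_llm::executor::MmKey const& key) const noexcept
        {
            // FNV-1a over the 32 hash bytes, then mix in offset and optional uuid.
            std::uint64_t h = 14695981039346656037ull;
            for (auto const b : key.hash)
            {
                h = (h ^ b) * 1099511628211ull;
            }
            h ^= static_cast<std::uint64_t>(key.startOffset);
            if (key.uuid)
            {
                h ^= std::hash<std::string>{}(*key.uuid);
            }
            return static_cast<std::size_t>(h);
        }
    };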
@@ -104,12 +104,14 @@ public:

     [[nodiscard]] SizeType32 constexpr getTensorParallelRank() const noexcept
     {
-        return mRank % mTensorParallelism;
+        // Layout: pp is outermost, then tp, then cp is innermost (consecutive).
+        return (mRank % (mTensorParallelism * mContextParallelism)) / mContextParallelism;
     }

     [[nodiscard]] SizeType32 constexpr getContextParallelRank() const noexcept
     {
-        return (mRank % (mTensorParallelism * mContextParallelism)) / mTensorParallelism;
+        // Layout: pp is outermost, then tp, then cp is innermost (consecutive).
+        return mRank % mContextParallelism;
     }

     [[nodiscard]] SizeType32 constexpr getLocalRank() const noexcept
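A standalone check of the corrected mapping, for tp = 2, cp = 2, pp = 2 with ranks laid out as ((ppRank * tp) + tpRank) * cp + cpRank (pp outermost, cp innermost):

    #include <cassert>

    int main()
    {
        int const tp = 2, cp = 2;
        int const rank = 5; // ppRank 1, tpRank 0, cpRank 1
        assert((rank % (tp * cp)) / cp == 0); // new getTensorParallelRank
        assert(rank % cp == 1);               // new getContextParallelRank
        // The old formulas gave tpRank = 5 % 2 = 1 and cpRank = (5 % 4) / 2 = 0,
        // i.e. they assumed tp, not cp, was the innermost dimension.
        return 0;
    }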
@@ -69,6 +69,11 @@ PREPROCESSOR_FLAGS += -DUSE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE
 # Do we want to use half accumulation for flash attention
 PREPROCESSOR_FLAGS += -DHALF_ACCUMULATION_FOR_FLASH_ATTENTION

+# Print the resulted sparsity given threshold in Skip-Softmax attention
+# Note: You only need to "python scripts/build_wheel.py -D SKIP_SOFTMAX_STAT=ON ..." to use it inside TRTLLM.
+# Turn this on manually only if you want to build&run the unittest (bin/fmha.exe) with SKIP_SOFTMAX_STAT.
+# PREPROCESSOR_FLAGS += -DSKIP_SOFTMAX_STAT
+
 # Add FLAGS when generating cubins.
 ifdef GENERATE_CUBIN
 PREPROCESSOR_FLAGS += -DGENERATE_CUBIN
@@ -6,6 +6,7 @@ markers =
     fmhca
     debug
     bench
+    needs_l40s
 # bin: unit tests
 # test: python script for invoking fmha.exe
 testpaths = bin test
@@ -154,7 +154,9 @@ spec_fields = (
     'head_size_v',
     'sage_block_sizes',
     'output_dtype',
-    'is_mtp')
+    'is_mtp',
+    'enable_skip_softmax',
+)
 kernel_spec = namedtuple('kernel_spec', spec_fields)
 kernel_spec.__new__.__defaults__ = (
     1,  # ctas_per_head
@@ -179,7 +181,9 @@
     0,  # head size of V
     None,  # sage_block_sizes
     None,  # output_dtype, same as dtype by default.
-    False)  # use MTP or not
+    False,  # use MTP or not
+    False,  # enable skip softmax
+)

 generate_cu_trtllm = os.environ.get('GENERATE_CU_TRTLLM',
                                     'False').lower() == 'true'
@@ -1435,6 +1439,7 @@ using Ktraits = {kernel_traits_header}
     USE_TMA_STORE,
     {enable_attn_logit_softcapping_flag},
     {return_softmax_stats_flag},
+    {enable_skip_softmax_flag},
     {output_dtype_},
     {sage_block_size_q},
     {sage_block_size_k},
@@ -1458,6 +1463,7 @@ using Ktraits_causal = {kernel_traits_header}
     USE_TMA_STORE,
     {enable_attn_logit_softcapping_flag},
     {return_softmax_stats_flag},
+    {enable_skip_softmax_flag},
     {output_dtype_}>;

 using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
@@ -1478,6 +1484,7 @@ using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
     USE_TMA_STORE && false,
     {enable_attn_logit_softcapping_flag},
     {return_softmax_stats_flag},
+    {enable_skip_softmax_flag},
     {output_dtype_}>;

 using Ktraits_custom_mask = {kernel_traits_header}
@@ -1498,6 +1505,7 @@ using Ktraits_custom_mask = {kernel_traits_header}
     USE_TMA_STORE && false,
     {enable_attn_logit_softcapping_flag},
     {return_softmax_stats_flag},
+    {enable_skip_softmax_flag},
     {output_dtype_}>;

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1835,6 +1843,8 @@ def encode_name(kernel_spec):

     if kernel_spec.enable_attn_logit_softcapping:
         feature_tags += '_softcapping'
+    if kernel_spec.enable_skip_softmax:
+        feature_tags += '_skipSoftmax'
     if kernel_spec.sage_block_sizes:
         feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}"
     if kernel_spec.output_dtype:
@@ -2131,6 +2141,8 @@ def get_kernel_code(kspec, kname, lname):

     return_softmax_stats_flag = pythonBoolean2cpp[kspec.return_softmax_stats]

+    enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
+
     # needed by warpspec kernels.
     fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"]
     kernel_traits_header = "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" if fp8_kernel \
@@ -2331,6 +2343,8 @@ def get_api_code(specs_names):
             f'&& sage_block_size_k == {sage_block_size_k} ' \
             f'&& sage_block_size_v == {sage_block_size_v} '

+        il_check += '&& enable_skip_softmax ' if kspec.enable_skip_softmax else '&& !enable_skip_softmax '
+
         il_check += '&& params.use_int8_scale_max ' if kspec.has_scale_max else '&& !params.use_int8_scale_max '

         slen = kspec.seq_len * kspec.ctas_per_head if not kspec.flash_attention else 0
@@ -2607,6 +2621,7 @@ const bool warp_specialization = launch_params.warp_specialization
 const bool use_tma = launch_params.use_tma;
 const bool use_flash_attention = launch_params.flash_attention;
 const bool enable_attn_logit_softcapping = launch_params.enable_attn_logit_softcapping;
+const bool enable_skip_softmax = launch_params.enable_skip_softmax;
 const int attention_input_layout = static_cast<int>(launch_params.attention_input_layout);
 // tiled variant uses ldgsts
 const bool use_tiled = launch_params.use_granular_tiling;
@@ -2785,6 +2800,8 @@ def get_kernel_traits_code(specs_names):
         enable_attn_logit_softcapping_flag = pythonBoolean2cpp[
             kspec.enable_attn_logit_softcapping]

+        enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
+
         tmp = dict(locals(), **kspec._asdict())

         if effective_sm < 90:
@@ -2903,7 +2920,8 @@
     {input_layout_flag},
     __use_tma_store__ /* USE_TMA_STORE */,
     {enable_attn_logit_softcapping_flag},
-    {return_softmax_stats_flag}>;
+    {return_softmax_stats_flag},
+    {enable_skip_softmax_flag}>;

 printf("%s %d %d %s %d %d\\n",
     \"{kname}\",
@@ -3062,9 +3080,16 @@ def get_kernel_traits_code(specs_names):
     # For now:
     # 1. Hopper head_size 128 kernel uses cubins for performance regressions.
     # 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins for accuracy regressions (will be fixed).
+    # 3. For skip-softmax attention feature, we force not to use cubins.
     # You should set the condition `use_cubin_header` to false if you have modified the source codes of those kernels that use cubins.
     # This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins.
-    def use_cubin_header(sm, head_size, dtype, output_dtype=None):
+    def use_cubin_header(sm,
+                         head_size,
+                         dtype,
+                         output_dtype=None,
+                         enable_skip_softmax=False):
+        if enable_skip_softmax:
+            return False
         if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
             return False
         return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)
@@ -3079,7 +3104,8 @@ def get_cubin_header(kernel_traits, specs_names):
     launchers_dict = {}
     for kspec, fname, lname, kname in specs_names:
         if generate_cu_trtllm and not use_cubin_header(
-                kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype):
+                kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype,
+                kspec.enable_skip_softmax):
             continue
         name = fname.replace('.', '_')
         data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
@@ -3111,8 +3137,9 @@ def get_cubin_header(kernel_traits, specs_names):
             'q_kv_', '').replace('q_paged_kv_', '').replace(
                 'q_k_v_', '').replace('ws_', '').replace(
                     'softcapping_',
-                    '').replace('sage_',
-                                '').replace('output_', ''))
+                    '').replace('sage_', '').replace(
+                        'skipSoftmax_',
+                        '').replace('output_', ''))
         flash_attention = 'flash_attention' in kname
         warp_specialization = 'tma_ws' in kname
         toks = tname.split('_')
@@ -3209,6 +3236,8 @@ def get_cubin_header(kernel_traits, specs_names):
         return_softmax_stats_flag = pythonBoolean2cpp[sm != '90' or (
             sm == '90' and '_softmax' in kname)]

+        enable_skip_softmax_flag = pythonBoolean2cpp['_skipSoftmax' in kname]
+
         # meta_unroll_step
         meta_unroll_step = unroll_step if ('_nl' in kname
                                            or '_ws' in kname) else '0'
@@ -3235,7 +3264,8 @@ def get_cubin_header(kernel_traits, specs_names):

         def get_lname_from_kname(kname: str) -> str:
             if use_cubin_header(int(sm), int(head_size), prec.lower(),
-                                output_prec.lower()):
+                                output_prec.lower(),
+                                enable_skip_softmax_flag):
                 return 'nullptr'
             lname = kname.replace('_kernel', '')
             mask_types = [
@@ -3253,15 +3283,15 @@ def get_cubin_header(kernel_traits, specs_names):
 {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
 {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
-{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
-'''.format(**locals()) if use_cubin_header(int(sm),
-                                           int(head_size), prec.lower(),
-                                           output_prec.lower()) else '''\
+{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
+'''.format(**locals()) if use_cubin_header(int(sm), int(head_size),
+                                           prec.lower(), output_prec.lower(),
+                                           enable_skip_softmax_flag) else '''\
 {{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
 {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
 0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
-{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
+{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
 '''.format(**locals())
         else:
             code = '''\
@@ -3269,7 +3299,7 @@ def get_cubin_header(kernel_traits, specs_names):
 {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
 {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
-{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}}}\
+{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}}}\
 '''.format(**locals())
         if sm in metadata_v2_dict:
             metadata_v2_dict[sm].append(code)
@@ -3377,7 +3407,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
     bool mAlibiSupported;
     bool mTiled;
     bool mEnableAttnLogitSoftcapping;
-    bool mReturnSoftmaxStats;{launcher_line}
+    bool mReturnSoftmaxStats;
+    bool mEnableSkipSoftmax;{launcher_line}
 }} sMhaKernelMetaInfosV2[] = {{
 {metadata_v2}
 }};
@@ -3438,6 +3469,7 @@ static const struct TestMetaV2
     bool mTiled;
     bool mEnableAttnLogitSoftcapping;
     bool mReturnSoftmaxStats;
+    bool mEnableSkipSoftmax;
 }} metaV2[] = {{
 {metadata_v2}
 }};
@@ -3484,7 +3516,8 @@ struct FusedMultiHeadAttentionKernelMetaInfoV2
     bool mAlibiSupported;
     bool mTiled;
     bool mEnableAttnLogitSoftcapping;
-    bool mReturnSoftmaxStats;{launcher_line}
+    bool mReturnSoftmaxStats;
+    bool mEnableSkipSoftmax;{launcher_line}
 }};

 extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[];
@@ -3580,7 +3613,8 @@ struct FusedMultiHeadAttentionKernelMetaInfoV2
     bool mAlibiSupported;
     bool mTiled;
     bool mEnableAttnLogitSoftcapping;
-    bool mReturnSoftmaxStats;{launcher_line}
+    bool mReturnSoftmaxStats;
+    bool mEnableSkipSoftmax;{launcher_line}
 }};

 extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[] = {{
@@ -3637,7 +3671,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_
     return '\n'.join(lines)

     target = "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled"
-    new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, nullptr},'
+    new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, false, nullptr},'
     result = modify_kernel_line(result, target, new_line)

     # make sure only one empty line at the end
@@ -3801,7 +3835,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):


 # Note this will be used in TRT-LLM.
-def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
+def enumerate_hgmma_flash_warpspec_kernels(specs,
+                                           sm=90,
+                                           dtype='fp16',
+                                           enable_skip_softmax=False):

     scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))

@@ -3851,7 +3888,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             enable_attn_logit_softcapping=enable_attn_logit_softcapping,
             return_softmax_stats=return_softmax,
             scheduling_mode=scheduling_mode,
-            input_layout=input_layout))
+            input_layout=input_layout,
+            enable_skip_softmax=enable_skip_softmax))

     specs.append(
         kernel_spec(
@@ -3883,7 +3921,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             enable_attn_logit_softcapping=enable_attn_logit_softcapping,
             return_softmax_stats=return_softmax,
             scheduling_mode=scheduling_mode,
-            input_layout=input_layout))
+            input_layout=input_layout,
+            enable_skip_softmax=enable_skip_softmax))

     specs.append(
         kernel_spec(
@@ -3915,7 +3954,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             enable_attn_logit_softcapping=enable_attn_logit_softcapping,
             return_softmax_stats=return_softmax,
             scheduling_mode=scheduling_mode,
-            input_layout=input_layout))
+            input_layout=input_layout,
+            enable_skip_softmax=enable_skip_softmax))
 '''
 smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
              + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
@@ -3967,7 +4007,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
                                            sm=90,
                                            dtype='e4m3',
                                            sage_block_sizes=None,
-                                           output_dtype=None):
+                                           output_dtype=None,
+                                           enable_skip_softmax=False):

     scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))

@@ -4021,7 +4062,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
             scheduling_mode=scheduling_mode,
             input_layout=input_layout,
             sage_block_sizes=sage_block_sizes,
-            output_dtype=output_dtype))
+            output_dtype=output_dtype,
+            enable_skip_softmax=enable_skip_softmax))

     # 64 < D <=128: KV_STEP = 128
     specs.append(
@@ -4056,7 +4098,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
             scheduling_mode=scheduling_mode,
             input_layout=input_layout,
             sage_block_sizes=sage_block_sizes,
-            output_dtype=output_dtype))
+            output_dtype=output_dtype,
+            enable_skip_softmax=enable_skip_softmax))

     # 128 < D <=256: KV_STEP = 128
     specs.append(
@@ -4092,7 +4135,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
             scheduling_mode=scheduling_mode,
             input_layout=input_layout,
             sage_block_sizes=sage_block_sizes,
-            output_dtype=output_dtype))
+            output_dtype=output_dtype,
+            enable_skip_softmax=enable_skip_softmax))

     if not skip_mla_combination:
         # context MLA (192x128)
@@ -6374,13 +6418,21 @@ def enumerate_kernels():
     enumerate_igmma_kernels(specs, sm=90)
     enumerate_qgmma_kernels(specs, sm=90)
     # need to add bf16 kernels if needed
-    enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16')
-    enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='bf16')
-    enumerate_qgmma_flash_warpspec_kernels(specs, sm=90, dtype='e4m3')
-    enumerate_qgmma_flash_warpspec_kernels(specs,
-                                           sm=90,
-                                           dtype='e4m3',
-                                           output_dtype="bf16")
+    for enable_skip_softmax in [False, True]:
+        if enable_skip_softmax and 'DISABLE_SKIP_SOFTMAX' in os.environ:
+            continue
+        enumerate_hgmma_flash_warpspec_kernels(
+            specs, sm=90, dtype='fp16', enable_skip_softmax=enable_skip_softmax)
+        enumerate_hgmma_flash_warpspec_kernels(
+            specs, sm=90, dtype='bf16', enable_skip_softmax=enable_skip_softmax)
+        enumerate_qgmma_flash_warpspec_kernels(
+            specs, sm=90, dtype='e4m3', enable_skip_softmax=enable_skip_softmax)
+        enumerate_qgmma_flash_warpspec_kernels(
+            specs,
+            sm=90,
+            dtype='e4m3',
+            output_dtype="bf16",
+            enable_skip_softmax=enable_skip_softmax)

     # For now SageAttention only needs BF16
     # block_size_q should be divisible by 64
@@ -256,7 +256,8 @@ struct Compute
                 actual_kv_seqlen, alibi_head_scale, \
                 USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \
                                 : (q_step_idx * STEP_Q + head_info.q_tile_offset), \
-                kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, kv_step_idx == kv_idx_end - 1);
+                kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \
+                &shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1);

 ////////////////////////////////////////////////////////////////////////////////////////////////

@@ -360,6 +361,12 @@ struct Compute
         // Contiguous QKV FMHA assumes q, and kv have the same sequence length.
         int const actual_kv_seqlen = SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen;

+        // Update threshold of Skip-Softmax
+        if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
+        {
+            softmax.skip_softmax_threshold = params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
+        }
+
         // Calculate the alibi head_scaling_factor.
         float alibi_head_scale
             = APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>(head_info.bidh, params.alibi_params) : 0.f;
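The threshold update divides a fixed scale factor by the actual KV length, so longer sequences use a proportionally smaller skip threshold. A standalone illustration (the 0.02f scale factor is hypothetical, not from the diff):

    #include <cstdio>
    #include <initializer_list>

    int main()
    {
        float const skip_softmax_threshold_scale_factor = 0.02f;
        for (int actual_kv_seqlen : {1024, 4096, 16384})
        {
            std::printf("kv_len=%5d -> threshold=%g\n", actual_kv_seqlen,
                skip_softmax_threshold_scale_factor / actual_kv_seqlen);
        }
        return 0; // prints ~2e-05, ~4.9e-06, ~1.2e-06
    }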
@@ -513,6 +520,13 @@ struct Compute
                 }
             }
         }
+#ifdef SKIP_SOFTMAX_STAT
+        if (tidx == 0)
+        {
+            atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks);
+            atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks);
+        }
+#endif
     }

 ////////////////////////////////////////////////////////////////////////////////////////////////
@@ -522,8 +536,15 @@ struct Compute
         Compute_tile_o& ctile_o, float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M],
         int const tidx, int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset,
         int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, Circular_buffer_kv_reader& cbr_v,
-        OrderedMutexAccessor& mutex, bool complete = false)
+        OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote, bool complete = false)
     {
+
+        // Skip-softmax vote initialization
+        if (tidx == 0)
+        {
+            // Note that we need a named_barrier_wait in compute_single_tile to make sure init is before voting.
+            *skip_softmax_vote = 1;
+        }
         // load the scales of K/V from global memory
 #define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \
     if constexpr (block_size > 0) \

@@ -557,6 +578,10 @@ struct Compute
             // Ctile_p is only used once by each n step.
             ctile_p.clear();
+
+            // If skip_softmax is enabled, make sure there is no racing between the initialization and writing of
+            // skip_softmax_vote.
+            named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128);
 
             // BMM1 (Q x K').
             warpgroup_arrive();
 
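> Editor's note: the named_barrier_wait added above separates the vote initialization (thread 0 writes 1) from the warp leaders' atomicAnd votes. The barrier id is derived per warpgroup; a compile-time sketch of that arithmetic (illustration only):

```cuda
// With 128 threads per warpgroup and SKIP_SOFTMAX_BARRIER_ID = 0x3 (see the
// kernel-traits hunk below), warpgroup 0 synchronizes on barrier 0x3 and
// warpgroup 1 on barrier 0x4, so the two compute warpgroups never alias.
constexpr int skip_softmax_barrier_id(int tid, int base_id = 0x3)
{
    return base_id + tid / 128;
}
static_assert(skip_softmax_barrier_id(0) == 0x3, "first warpgroup");
static_assert(skip_softmax_barrier_id(127) == 0x3, "still first warpgroup");
static_assert(skip_softmax_barrier_id(128) == 0x4, "second warpgroup");
```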
@@ -626,8 +651,22 @@ struct Compute
             softmax.apply_alibi_and_mask<APPLY_MASK>(
                 ctile_p, params.alibi_params, alibi_head_scale, actual_kv_seqlen, row_offset, col_offset);
 
-            // Softmax Exp, max/sum, and update scales.
-            softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum);
+            // Softmax Exp, max/sum, and update scales. If it returns false, we skip the rest.
+            if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote))
+            {
+                if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1)
+                {
+                    // Notify another warpgroup to execute QGMMA.
+                    mutex.named_bar_arrive();
+                }
+                // Need to wait for V, otherwise compute-sanitizer synccheck will fail.
+                int ready2 = cbr_v.peek();
+                if (!ready2)
+                {
+                    cbr_v.wait();
+                }
+                return;
+            }
 
             // experiments show that here is the best place to load scales of V
             float scales_v[SAGE_BLOCKS_PER_STEP_V];

@@ -17,6 +17,8 @@
 
 #pragma once
 
+#include "fmha/hopper/arrive_wait.h"
+
 #include <fmha/softmax.h>
 #include <fmha/traits.h>
 #include <fmha/utils.h>
@@ -104,6 +106,12 @@ struct Softmax_base
         CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK
     };
+
+    // There are 2 warpgroups so 0x3 and 0x4 are used
+    enum
+    {
+        SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID
+    };
 
     // Ctor.
     template <typename Params>
     inline __device__ Softmax_base(Params params, int tidx)
@@ -114,6 +122,11 @@ struct Softmax_base
         , log2_chunked_attention_size_(params.log2_chunked_attention_size)
         , packed_mask_ptr_{reinterpret_cast<uint32_t*>(params.packed_mask_ptr)}
         , params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes}
+#ifdef SKIP_SOFTMAX_STAT
+        , total_blocks(0)
+        , skipped_blocks(0)
+#endif
+        , skip_softmax_threshold(0)
     {
 
         int warp = tidx / 32;
@@ -330,24 +343,22 @@ struct Softmax_base
     }
 
     // Calculate max/sum, and update flash-attention scales.
+    // Returns false if skipped due to the skip-softmax attention feature.
     template <bool IS_FIRST_COL>
-    inline __device__ void compute_and_update_scale(
-        float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
+    inline __device__ bool compute_and_update_scale(
+        float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
     {
         float const scale = reinterpret_cast<float const&>(scale_bmm1_);
+
+        // whether this warpgroup skips the softmax
+        constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
+        bool skip = may_skip;
 
         // Row-wise max of current tile.
 #pragma unroll
         for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
         {
-            if (IS_FIRST_COL)
-            {
-                local_max_[mi] = elt_[mi][0];
-            }
-            else
-            {
-                local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
-            }
+            local_max_[mi] = elt_[mi][0];
 #pragma unroll
             for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
             {
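> Editor's note: the restructuring above is load-bearing. Previously the running max was folded into local_max_ up front, which would make exp(local_max - global_max) >= 1 and defeat any threshold test. A plain sketch of the new ordering (illustrative, not the kernel code):

```cuda
#include <algorithm>
#include <cmath>

// Sketch: keep the tile-local max pure, test the skip condition against the
// previous running max, and only then merge the two.
void update_row(float const* elts, int n, bool first_col, float& running_max,
    float threshold, bool& skip)
{
    float local_max = elts[0];
    for (int i = 1; i < n; ++i)
    {
        local_max = std::max(local_max, elts[i]);
    }
    if (first_col)
    {
        running_max = local_max; // the first tile always contributes
        skip = false;
    }
    else
    {
        skip = skip && (std::exp(local_max - running_max) < threshold);
        running_max = std::max(running_max, local_max); // merge after the test
    }
}
```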
@@ -355,6 +366,56 @@ struct Softmax_base
             }
             local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
             local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
+
+            if constexpr (may_skip)
+            {
+                // AND(&) the CORES_M results, then `skip` means whether to skip
+                // the CORES_M(=2) rows
+                if constexpr (!EXP2F_OPTIMIZATION)
+                {
+                    skip &= expf(local_max_[mi] - global_max[mi]) < skip_softmax_threshold;
+                }
+                else
+                {
+                    skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold;
+                }
+            }
+
+            if (!IS_FIRST_COL)
+            {
+                local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
+            }
+        }
+
+        if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
+        {
+#ifdef SKIP_SOFTMAX_STAT
+            total_blocks++;
+#endif
+            if constexpr (may_skip)
+            {
+
+                // AND(&) the results together in a warp, then `skip` means whether to skip
+                // all the 16 rows managed by this warp.
+                // Each 4 threads (e.g. T0~T3) have the same `skip`, so only 0x11111111 would be needed
+                // instead of 0xffffffff. But the perf is the same.
+                skip = __all_sync(0xffffffff, skip);
+                if (threadIdx.x % 32 == 0)
+                {
+                    // The leader of each warp votes.
+                    atomicAnd(skip_softmax_vote, uint32_t(skip));
+                }
+                // WG0 uses the 0x3 barrier, WG1 uses the 0x4 barrier.
+                named_barrier_wait(SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
+                skip = *((uint32_t volatile*) skip_softmax_vote);
+                if (skip)
+                {
+#ifdef SKIP_SOFTMAX_STAT
+                    skipped_blocks++;
+#endif
+                    return false;
+                }
+            }
         }
 
         // Softmax Exp.
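> Editor's note: the voting added above reduces the skip decision in three levels: per-thread rows, then a warp-wide AND, then an AND across warps in shared memory. A self-contained sketch of the same pattern, using __syncthreads() as a stand-in for the per-warpgroup named barrier (assumption: one 128-thread block standing in for one warpgroup):

```cuda
__global__ void skip_vote_sketch(
    float const* local_max, float const* global_max, float threshold, bool* out)
{
    __shared__ unsigned vote;
    if (threadIdx.x == 0)
    {
        vote = 1u; // init: assume "skip" until some warp vetoes
    }
    __syncthreads(); // init must precede all votes (named barrier in the kernel)

    // Level 1: per-thread test; level 2: AND across the 32 lanes of a warp.
    bool skip = expf(local_max[threadIdx.x] - global_max[threadIdx.x]) < threshold;
    skip = __all_sync(0xffffffffu, skip);
    // Level 3: warp leaders AND into the shared vote.
    if (threadIdx.x % 32 == 0)
    {
        atomicAnd(&vote, static_cast<unsigned>(skip));
    }
    __syncthreads(); // all votes in before anyone reads the result back
    out[threadIdx.x] = *reinterpret_cast<unsigned volatile*>(&vote) != 0u;
}
```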
@@ -436,6 +497,7 @@ struct Softmax_base
                 global_max[mi] = max_new;
             }
         }
+        return true;
     }
 
     // Update flash attention scales and pack elements for BMM2.
@@ -513,6 +575,13 @@ struct Softmax_base
     float correction_[Mma_tile_p::CORES_M];
     // The packed mask.
     uint4 packed_mask_;
+    // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold.
+    float skip_softmax_threshold;
+#ifdef SKIP_SOFTMAX_STAT
+    // Statistics of skip-softmax
+    uint32_t total_blocks;
+    uint32_t skipped_blocks;
+#endif
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -868,9 +937,10 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
     }
 
     // Calculate max/sum, and update flash-attention scales.
+    // Returns false if skipped due to the skip-softmax attention feature.
     template <bool IS_FIRST_COL>
-    inline __device__ void compute_and_update_scale(
-        float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
+    inline __device__ bool compute_and_update_scale(
+        float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
     {
         float const scale = reinterpret_cast<float const&>(this->scale_bmm1_);
         float(&local_max_)[Mma_tile_p::CORES_M] = this->local_max_;
@@ -878,18 +948,15 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
         float(&correction_)[Mma_tile_p::CORES_M] = this->correction_;
         float(&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2] = this->elt_;
+
+        // whether this warpgroup skips the softmax
+        constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
+        bool skip = may_skip;
 
         // Row-wise max of current tile.
 #pragma unroll
         for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
         {
-            if (IS_FIRST_COL)
-            {
-                local_max_[mi] = elt_[mi][0];
-            }
-            else
-            {
-                local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
-            }
+            local_max_[mi] = elt_[mi][0];
 #pragma unroll
             for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
             {
@@ -897,6 +964,56 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
             }
             local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
             local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
+            // AND(&) the CORES_M results, then `skip` means whether to skip
+            // the CORES_M(=2) rows
+            if constexpr (may_skip)
+            {
+                // AND(&) the CORES_M results, then `skip` means whether to skip
+                // the CORES_M(=2) rows
+                if constexpr (!EXP2F_OPTIMIZATION)
+                {
+                    skip &= expf(local_max_[mi] - global_max[mi]) < this->skip_softmax_threshold;
+                }
+                else
+                {
+                    skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < this->skip_softmax_threshold;
+                }
+            }
+            if (!IS_FIRST_COL)
+            {
+                local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
+            }
+        }
+
+        if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
+        {
+#ifdef SKIP_SOFTMAX_STAT
+            this->total_blocks++;
+#endif
+
+            if constexpr (may_skip)
+            {
+                // AND(&) the results together in a warp, then `skip` means whether to skip
+                // all the 16 rows managed by this warp.
+                // Each 4 threads (e.g. T0~T3) have the same `skip`, so only 0x11111111 would be needed
+                // instead of 0xffffffff. But the perf is the same.
+                skip = __all_sync(0xffffffff, skip);
+                if (threadIdx.x % 32 == 0)
+                {
+                    // The leader of each warp votes.
+                    atomicAnd(skip_softmax_vote, uint32_t(skip));
+                }
+                // WG0 uses the 0x3 barrier, WG1 uses the 0x4 barrier.
+                named_barrier_wait(Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
+                skip = *((uint32_t volatile*) skip_softmax_vote);
+                if (skip)
+                {
+#ifdef SKIP_SOFTMAX_STAT
+                    this->skipped_blocks++;
+#endif
+                    return false;
+                }
+            }
         }
 
         // Softmax Exp.
@@ -987,6 +1104,7 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
                 global_max[mi] = max_new;
             }
         }
+        return true;
     }
 
     // Update flash attention scales and pack elements for BMM2.

@@ -71,6 +71,8 @@ template <
     bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
     // Save softmax stats ?
     bool RETURN_SOFTMAX_STATS_ = false,
+    // Enable skip softmax attention feature
+    bool ENABLE_SKIP_SOFTMAX_ = false,
     // The output type (only used by fp8 kernels).
     typename OutputType = typename Instruction_traits<STEP_Q_, STEP_KV_, 0, false, false>::A_type,
     // The sage attention block size for Q, K and V
@@ -290,6 +292,12 @@ struct Kernel_traits
         USE_CUSTOM_MASK = ATTENTION_MASK_TYPE_ == 3
     };
+
+    // Are we enabling the skip softmax attention feature?
+    enum
+    {
+        ENABLE_SKIP_SOFTMAX = ENABLE_SKIP_SOFTMAX_
+    };
 
     static_assert(!USE_CUSTOM_MASK || STEP_KV == 64 || STEP_KV == 128 || STEP_KV == 256, "Not implemented!");
 
     // Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
@@ -384,6 +392,8 @@ struct Kernel_traits
     // Named barrier ids
     static constexpr int DMA_SYNC_BARRIER_ID = 0x1;
     static constexpr int MMA_SYNC_BARRIER_ID = 0x2;
+    // There are 2 warpgroups so 0x3 and 0x4 are used for skip-softmax
+    static constexpr int SKIP_SOFTMAX_BARRIER_ID = 0x3;
 
     // How many threads get involved in the dma group.
     enum
@@ -518,6 +528,10 @@ struct Kernel_traits
         // Mutex
         OrderedMutex compute_mutex;
+
+        // 4 warps in a warpgroup vote on an atomic variable in shared memory
+        // to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive KV steps.
+        uint32_t skip_softmax_votes[2][NUM_COMPUTE_GROUPS];
 
         inline __device__ void init(int tid0)
         {
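> Editor's note: the two-deep first dimension mirrors the call site earlier in this diff, which passes &shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id]. A one-line sketch of the indexing, assuming the two-warpgroup design (NUM_COMPUTE_GROUPS == 2):

```cuda
#include <cstdint>

// Consecutive KV steps alternate slots, so step i can re-initialize its vote
// while the vote for step i - 1 is still being read.
uint32_t& vote_slot(uint32_t (&votes)[2][2], int kv_step_idx, int warpgroup_id)
{
    return votes[kv_step_idx & 1][warpgroup_id];
}
```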
@@ -580,6 +594,8 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
     bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
     // Save softmax stats ?
     bool RETURN_SOFTMAX_STATS_ = false,
+    // Enable skip softmax attention feature
+    bool ENABLE_SKIP_SOFTMAX_ = false,
     // The output type (only used by fp8 kernels).
     typename OutputType = e4m3_t,
     // The sage attention block size for Q, K and V
@@ -588,14 +604,15 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
     : public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
           NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_,
           ENABLE_MUTEX_, SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_,
-          RETURN_SOFTMAX_STATS_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>
+          RETURN_SOFTMAX_STATS_, ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_,
+          SAGE_BLOCK_SIZE_V_>
 {
 
     // Base class.
     using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
         NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
         SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_, RETURN_SOFTMAX_STATS_,
-        OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
+        ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
 
     enum
     {
@@ -693,6 +710,10 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
         // Mutex
         OrderedMutex compute_mutex;
+
+        // 4 warps in a warpgroup vote on an atomic variable in shared memory
+        // to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive KV steps.
+        uint32_t skip_softmax_votes[2][Base::NUM_COMPUTE_GROUPS];
 
         inline __device__ void init(int tid0)
         {

@@ -276,7 +276,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
     // scale factors
     float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
     // flags
-    bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi)
+    bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi,
+    float const skip_softmax_threshold_scale_factor)
 {
 
     memset(&params, 0, sizeof(params));
@@ -421,6 +422,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
         params.enable_i2f_trick
             = -double(1 << 22) * double(scale_bmm2) <= -128.f && double(1 << 22) * double(scale_bmm2) >= 127.f;
     }
+
+    // Skip-softmax attention
+    params.skip_softmax_threshold_scale_factor = skip_softmax_threshold_scale_factor;
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -429,7 +433,7 @@ static inline void determine_launch_params(Launch_params& launch_params, Data_ty
     const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout,
     bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma,
     bool const force_non_flash_attention, bool const force_non_warp_specialization,
-    bool const force_non_granular_tiling, bool const force_fp32_acc,
+    bool const force_non_granular_tiling, bool const force_fp32_acc, float const skip_softmax_threshold_scale_factor,
     // device props
     const cudaDeviceProp props)
 {
@@ -470,6 +474,9 @@ static inline void determine_launch_params(Launch_params& launch_params, Data_ty
             "are not supported on Ada currently.\n");
         launch_params.use_granular_tiling = false;
     }
+
+    // Enable skip softmax attention or not.
+    launch_params.enable_skip_softmax = skip_softmax_threshold_scale_factor > 0.f;
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -589,6 +596,9 @@ int main(int argc, char** argv)
     // Use attention sinks (added to the denominator of softmax)
     bool use_attention_sinks = false;
+
+    // Skip-softmax attention
+    float skip_softmax_threshold_scale_factor = 0;
 
     // Read the parameters from the command-line.
     for (int ii = 1; ii < argc; ++ii)
     {
@@ -885,6 +895,10 @@ int main(int argc, char** argv)
         {
             use_attention_sinks = true;
         }
+        else if (!strcmp(argv[ii], "-skip-softmax-threshold-scale-factor") && ++ii < argc)
+        {
+            skip_softmax_threshold_scale_factor = strtof(argv[ii], nullptr);
+        }
         else
         {
             fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]);
@@ -1057,7 +1071,7 @@ int main(int argc, char** argv)
     Launch_params launch_params;
     determine_launch_params(launch_params, data_type, sm, s, d, attention_mask_type, input_layout, interleaved,
         ignore_b1opt, force_unroll, use_tma, force_non_flash_attention, force_non_warp_specialization,
-        force_non_granular_tiling, force_fp32_acc, props);
+        force_non_granular_tiling, force_fp32_acc, skip_softmax_threshold_scale_factor, props);
 
     // The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D.
     const size_t qkv_size = s * b * h * (2 * d + dv);
@@ -1713,7 +1727,13 @@ int main(int argc, char** argv)
         tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d,
         packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d,
         softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
-        use_int8_scale_max, interleaved, is_s_padded, has_alibi);
+        use_int8_scale_max, interleaved, is_s_padded, has_alibi, skip_softmax_threshold_scale_factor);
+#ifdef SKIP_SOFTMAX_STAT
+    FMHA_CHECK_CUDA(cudaMalloc(&params_v2.skip_softmax_total_blocks, sizeof(uint32_t)));
+    FMHA_CHECK_CUDA(cudaMalloc(&params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t)));
+    FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_total_blocks, 0, sizeof(uint32_t)));
+    FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_skipped_blocks, 0, sizeof(uint32_t)));
+#endif
 
     // total number of tokens is needed to set TMA desc on the host.
     launch_params.total_q_seqlen = q_seqlens[b];
@@ -2101,6 +2121,18 @@ int main(int argc, char** argv)
             non_fused_elapsed / fused_elapsed, total_flops / (fused_elapsed / float(runs) / 1e-9),
             total_bytes / (fused_elapsed / float(runs) / 1e-6));
     }
+#ifdef SKIP_SOFTMAX_STAT
+    if (skip_softmax_threshold_scale_factor > 0)
+    {
+        uint32_t total_blocks, skipped_blocks;
+        FMHA_CHECK_CUDA(
+            cudaMemcpy(&total_blocks, params_v2.skip_softmax_total_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
+        FMHA_CHECK_CUDA(cudaMemcpy(
+            &skipped_blocks, params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
+        printf("Skip-Softmax .: %u / %u = %.2f%%\n", skipped_blocks, total_blocks,
+            total_blocks ? 100.f * skipped_blocks / total_blocks : 0.f);
+    }
+#endif
 #if defined(DEBUG_HAS_PRINT_BUFFER)
     FMHA_CHECK_CUDA(cuda_memcpy_d2h(print_buffer.data(), params.print_ptr, print_buffer.size(), DATA_TYPE_FP32));
 
@@ -2141,6 +2173,11 @@ int main(int argc, char** argv)
     FMHA_CHECK_CUDA(cudaFree(kv_cache_block_offsets_d));
     FMHA_CHECK_CUDA(cudaFree(contiguous_kv_d));
     FMHA_CHECK_CUDA(cudaFree(softmax_stats_d));
+    FMHA_CHECK_CUDA(cudaFree(attention_sinks_d));
+#ifdef SKIP_SOFTMAX_STAT
+    FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_total_blocks));
+    FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_skipped_blocks));
+#endif
 
     free(qkv_h);
     free(mask_h);

@@ -283,6 +283,16 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
             float* scales;
         } q, k, v;
     } sage;
+
+    // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
+    // A positive value means skip-softmax is enabled.
+    float skip_softmax_threshold_scale_factor = 0;
+
+#ifdef SKIP_SOFTMAX_STAT
+    // Statistics of skip-softmax, pointers of device memory for output
+    uint32_t* skip_softmax_total_blocks;
+    uint32_t* skip_softmax_skipped_blocks;
+#endif
 };
 
 #endif
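> Editor's note: a hedged reading of the comment above: with threshold = s / seqlen, every element of a skipped block satisfies exp(elt - global_max) < s / seqlen, so all skipped blocks together can hide less than s of unnormalized probability mass, while the row maximum alone contributes exp(0) = 1 to the denominator. In that reading, the scale factor bounds the softmax error independently of sequence length:

```cuda
// Illustration of the bound, not library code.
constexpr float hidden_mass_bound(float scale_factor, int seqlen)
{
    float const per_element = scale_factor / seqlen; // the skip threshold
    return per_element * seqlen;                     // <= scale_factor for any seqlen
}
static_assert(hidden_mass_bound(8.f, 4096) == 8.f, "bound is the scale factor itself");
```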
@@ -322,6 +332,8 @@ struct Fused_multihead_attention_launch_params
     // hardware properties to determine how to launch blocks
     int multi_processor_count = 0;
     int device_l2_cache_size = 0;
+    // skip softmax attention
+    bool enable_skip_softmax = false;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -177,4 +177,13 @@ struct Fused_multihead_attention_params_v2
             float* scales;
         } q, k, v;
     } sage;
+
+    // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
+    // A positive value means skip-softmax is enabled.
+    float skip_softmax_threshold_scale_factor = 0;
+#ifdef SKIP_SOFTMAX_STAT
+    // Statistics of skip-softmax, pointers of device memory for output
+    uint32_t* skip_softmax_total_blocks;
+    uint32_t* skip_softmax_skipped_blocks;
+#endif
 };

@@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
 #define SLIDING_WINDOW 0
 #endif
+
+#ifndef SKIP_SOFTMAX_ATTN
+#define SKIP_SOFTMAX_ATTN 0
+#endif
+
+#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS
+#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
+#endif
+
+#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
+#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
+#endif
 
 // 0 - no PDL
 // 1 - naive PDL
 // 2 - aggressive PDL (implemented only in mha_sm90.cu for now)

@@ -89,7 +89,8 @@ cpp_file_prefix_text = R"""/*
 
 #include "tensorrt_llm/common/config.h"
 
-TRTLLM_NAMESPACE_BEGIN
+namespace tensorrt_llm
+{
 namespace kernels
 {
 // clang-format off
@@ -98,7 +99,7 @@ namespace kernels
 cpp_file_suffex_text = R"""
 // clang-format on
 } // namespace kernels
-TRTLLM_NAMESPACE_END
+}
 """
 
 cubin_meta_info_struct_prefix_text = R"""
@@ -438,8 +439,9 @@ if __name__ == "__main__":
         CompileMacroOption('HEAD_ELEMS', 'd', [128]),
         CompileMacroOption('BEAM_WIDTH', 'beam', [1]),
         CompileMacroOption('CACHE_ELEM_ENUM', 'kvt', [0, 1, 2]),
-        CompileMacroOption('TOKENS_PER_PAGE', 'pagedKV',
-                           [0, 64, 128]),  # 0 denotes contiguous kv cache.
+        CompileMacroOption(
+            'TOKENS_PER_PAGE', 'pagedKV',
+            [0, 32, 64, 128]),  # 0 denotes contiguous kv cache.
         CompileMacroOption('HEAD_GRP_SIZE', 'nqpkv', [0]),
         CompileMacroOption('M_TILESIZE', 'm', [16, 32]),
     ]]

@@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
         asm volatile("trap;\n");
         return 0;
     }();
+    assert(__cvta_generic_to_shared(data) % baseAlign == 0);
    uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
    return MatDesc{
        /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),

@@ -465,13 +465,10 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
 #if SPEC_DEC
 #define MMAS_N_PER_MASK 2
 
-__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
-    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
 #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-    ,
-    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
-#endif
-)
+__device__ inline void applyMaskFromInputSlidingAndSpecDec(Warp const& warp, WarpAcc& acc, MaskType const* mask,
+    uint32_t rowOffset, uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize,
+    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg)
 {
     uint32_t const idxInQuad = laneId() % 4;
     uint32_t const idxQuad = laneId() / 4;
@@ -479,7 +476,6 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
     uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
     uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
     constexpr uint64_t fullMask = ~uint64_t{0};
-#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
     Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
     Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
     bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
@@ -487,11 +483,6 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
     int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
     uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
     bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
-#else
-    constexpr bool ctaNeedBegMask = false;
-    bool const ctaNeedSpecDecMask = true;
-    int32_t const tok0NbMaskOut = -2147483648;
-#endif
     bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
 
     if (!needMask)
@@ -559,6 +550,61 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
 }
 #endif
+
+__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
+    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
+{
+    uint32_t const idxInQuad = laneId() % 4;
+    uint32_t const idxQuad = laneId() / 4;
+    // Packed mask is aligned with 32 bits (2 uint16_t).
+    uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
+    uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
+#pragma unroll
+    for (uint32_t m = 0; m < acc.rows; m++)
+    {
+#pragma unroll
+        for (uint32_t i = 0; i < InstAcc::rows; i++)
+        {
+            uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
+#pragma unroll
+            for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
+            {
+                uint32_t const firstCol = instN * mask_n * MMAS_N_PER_MASK + InstAcc::cols * idxInQuad;
+                uint32_t const lastCol = firstCol + instN * (MMAS_N_PER_MASK - 1) + InstAcc::cols - 1;
+                uint32_t const maskPos0 = firstCol + actualQSeqLen < nbValidCols
+                    ? 0u
+                    : min(firstCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
+                uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
+                    ? 0u
+                    : min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
+                uint32_t packedMask = 0u;
+                uint32_t const maskPosStart = (maskPos0 / 16) * 16;
+                reinterpret_cast<uint16_t*>(&packedMask)[0]
+                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
+                reinterpret_cast<uint16_t*>(&packedMask)[1]
+                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
+#pragma unroll
+                for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
+                {
+#pragma unroll
+                    for (uint32_t j = 0; j < InstAcc::cols; j++)
+                    {
+                        uint32_t const n = (mask_n * MMAS_N_PER_MASK + nj);
+                        uint32_t const col = instN * n + InstAcc::cols * idxInQuad + j;
+                        // bool const maskFlag = col + qSeqLen < nbValidCols ? true : mask[tokenRow * qSeqLen + (col +
+                        // qSeqLen - nbValidCols)];
+                        bool const maskFlag = col + actualQSeqLen < nbValidCols
+                            ? true
+                            : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
+                        acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
+                    }
+                }
+            }
+        }
+    }
+}
+
+#endif
 
 __device__ inline QuadRegRowMax warpTileOnlineSoftmax(Warp const& warp, QuadRegRowMax const& rowMaxHint, WarpAcc& acc)
 {
     QuadRegRowMax rowMax = rowMaxHint;
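> Editor's note: the new non-sliding applyMaskFromInput keeps the packed-mask layout used by the sliding variant: one bit per (row, column), rows padded to a 32-bit boundary and fetched through 16-bit halves. A scalar sketch of a single-bit lookup under those assumptions (not the kernel loop itself):

```cuda
#include <cstdint>

// Sketch only: mirrors divUp(qSeqLen, 32) * 2 uint16 words per row and the
// (1u << (pos - maskPosStart)) test in the loop above.
bool mask_bit(uint16_t const* mask16, uint32_t row, uint32_t qSeqLen, uint32_t pos)
{
    uint32_t const wordsPerRow = ((qSeqLen + 31u) / 32u) * 2u;
    uint16_t const word = mask16[row * wordsPerRow + pos / 16u];
    return (word >> (pos % 16u)) & 1u;
}
```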
@@ -1655,7 +1701,7 @@ CUBIN_EXPORT __global__
     uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
     int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
     uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);
+    bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);
 #elif SLIDING_WINDOW
     bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
     assert(!SPEC_DEC || !rtIsReallySliding);
@@ -1673,7 +1719,8 @@
 
     uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
 #if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
-    uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
+    uint32_t const nbSeqItersWithoutMask
+        = rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTile.x;
 #elif SPEC_DEC
     uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
 #endif
@@ -1960,12 +2007,18 @@
         if (seqIter >= nbSeqItersWithoutMask)
         {
             uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
-            applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
 #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-                ,
-                tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
+            if (rtIsReallySliding)
+            {
+                applyMaskFromInputSlidingAndSpecDec(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen,
+                    actualQSeqLen, headGrpSize, tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg);
+            }
+            else
 #endif
-            );
+            {
+                applyMaskFromInput(
+                    warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
+            }
         }
 #else
         bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
@@ -2734,6 +2787,25 @@ static constexpr auto kernel_mha = kernel_mha_impl;
 #endif
 
 #ifndef GENERATE_CUBIN
+uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
+{
+    if (!allowMultiBlockMode)
+    {
+        return 1;
+    }
+    auto const env = std::getenv("XQA_NB_SUB_SEQ");
+    if (env != nullptr)
+    {
+        int32_t const val = std::stoi(env);
+        if (val > 0)
+        {
+            return val;
+        }
+    }
+    return std::min<uint32_t>(
+        std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
+}
+
 void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
 #if SLIDING_WINDOW
     uint32_t slidingWinSize,
@@ -2771,6 +2843,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
     // int8/fp8 KV cache.
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
+#endif
+#if SKIP_SOFTMAX_ATTN
+    float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+    uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
+    uint32_t* __restrict__ totalBlockCount,   // for compatibility with mha_sm90.cu only
+#endif
 #endif
     uint32_t* semaphores, void* scratch, cudaStream_t stream)
 {
@@ -2793,24 +2872,7 @@
     uint32_t const nbQHeads = nbKHeads * headGrpSize;
 
     // const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1;
-    uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
-    {
-        if (!allowMultiBlockMode)
-        {
-            return 1;
-        }
-        auto const env = std::getenv("XQA_NB_SUB_SEQ");
-        if (env != nullptr)
-        {
-            int32_t const val = std::stoi(env);
-            if (val > 0)
-            {
-                return val;
-            }
-        }
-        return std::min<uint32_t>(
-            std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
-    }();
+    uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
     // gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq
 #if SPEC_DEC
     const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock);

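> Editor's note: hoisting the lambda into computeNbSubSeqPerSeqMHA (declared in the header hunk below) lets host code query the multi-block split before launching. A minimal usage sketch, assuming the declaring header is included:

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

int main()
{
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, /*device=*/0);
    uint32_t const nbSubSeq = computeNbSubSeqPerSeqMHA(
        prop, /*batchSize=*/1, /*nbKHeads=*/8, /*maxSeqLen=*/32768);
    std::printf("sub-sequences per sequence: %u\n", nbSubSeq);
    return 0;
}
```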
@@ -90,6 +90,9 @@ struct BeamSearchParams
     // match trt-llm API.
 };
+
+uint32_t computeNbSubSeqPerSeqMHA(
+    cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
 
 void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
 #if SLIDING_WINDOW
     uint32_t slidingWinSize,
@@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
     // int8/fp8 KV cache.
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
+#endif
+#if SKIP_SOFTMAX_ATTN
+    float const skipSoftmaxThresholdScaleFactor,
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+    uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
+#endif
 #endif
     uint32_t* semaphores, void* scratch, cudaStream_t stream);
+
+uint32_t computeNbSubSeqPerSeqHopperF8MHA(
+    cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
 
 void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
 #if SLIDING_WINDOW
     uint32_t slidingWinSize,
@@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
     // int8/fp8 KV cache.
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
+#endif
+#if SKIP_SOFTMAX_ATTN
+    float const skipSoftmaxThresholdScaleFactor,
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+    uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
+#endif
 #endif
     uint32_t* semaphores, void* scratch, cudaStream_t stream);
 
@@ -49,6 +49,10 @@ static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is to
 #define SWAP_AB (!SPEC_DEC)
 #endif
+
+#if SKIP_SOFTMAX_ATTN
+static_assert(SWAP_AB && USE_PAGED_KV_CACHE && !SPEC_DEC && BEAM_WIDTH == 1, "SKIP_SOFTMAX_ATTN is not supported.");
+#endif
 
 #define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT)
 
 inline constexpr bool swapAB = SWAP_AB;
@@ -138,26 +142,38 @@ using PaddedOutHead = PaddedInputHead;
 
 struct alignas(128) SharedMem
 {
+    using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
     using KBuffer = Array2D<LdGrain, gemm0CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes)>;
-    static constexpr uint32_t nbKBuf = 2;
-    KBuffer k[nbKBuf]; // as is loaded from global mem.
     using XBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerXPart>, nbXParts>;
-    static constexpr uint32_t nbXBuf
-        = 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
     using VBuffer = Vec<Array2D<LdGrain, gemm1CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes),
         sizeof(XBuffer) % (cacheHeadPartBytes * 8) == 0>,
         cacheHeadNbParts>;
 #if !SWAP_AB
     using VTBuffer = Array2D<LdGrain, headElems, exactDiv(gemm1CtaTileNbTokens, cacheElemsPerGrain), true>;
 #endif
-    static constexpr uint32_t nbVBuf = 2;
 #if CACHE_ELEM_ENUM == 0
     using OutSwizzleBuf = Array2D<LdGrain, ctaNbQHeads, grainsPerPaddedInputHead>;
 #elif CACHE_ELEM_ENUM == 2
     using OutSwizzleBuf = Array2D<Vec<Vec<InputElem, 4>, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
 #endif
+
+#if SKIP_SOFTMAX_ATTN
+    static constexpr uint32_t nbKBuf = 2;
+    static constexpr uint32_t nbVBuf = 3; // @fixme: skip_softmax_attn: for skip softmax attn, an extra VBuffer is used
+    static constexpr uint32_t nbXBuf
+        = 3 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
+#else
+    static constexpr uint32_t nbKBuf = 2;
+    static constexpr uint32_t nbVBuf = 2;
+    static constexpr uint32_t nbXBuf
+        = 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
+#endif
     static_assert(nbXBuf == nbVBuf);
+
+    // note: buffers used for GMMA may have additional alignment requirements
+    KBuffer k[nbKBuf]; // as is loaded from global mem.
+    QBuffer q;         // For gmma math. Conversion done if needed.
 
     union ReusedXVOutSwizzleBuf
     {
         struct XV
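> Editor's note: the @fixme comment above says skip-softmax needs an extra V stage; presumably the V loader runs one buffer ahead so it can consult the gemm0 skip vote before committing a load. A compile-time sketch of the buffer-count arithmetic (tile sizes below are placeholders):

```cuda
// Illustration of the constants above, not the kernel's types.
constexpr unsigned nbXVBuf(bool skipSoftmaxAttn, unsigned gemm0Tile, unsigned gemm1Tile)
{
    unsigned const depth = skipSoftmaxAttn ? 3u : 2u;
    return depth * (gemm0Tile >= gemm1Tile ? 1u : gemm1Tile / gemm0Tile);
}
static_assert(nbXVBuf(false, 64, 64) == 2, "baseline double buffering");
static_assert(nbXVBuf(true, 64, 64) == 3, "one extra stage under skip-softmax");
```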
@@ -196,9 +212,6 @@ struct alignas(128) SharedMem
         return reusedXVOutSwizzleBuf[i].outSwizzle;
     }
 
-    using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
-    QBuffer q; // For gmma math. Conversion done if needed.
-
     // @fixme: move these into reusedXVOutSwizzleBuf
 #if SWAP_AB
     ShmQWiseVec xColMax[nbXBuf];
@@ -220,6 +233,11 @@ struct alignas(128) SharedMem
     Vec<KVCachePageIndex, nbPagesPerTile> pages[2]; // one for K and one for V
 #endif
+
+#if SKIP_SOFTMAX_ATTN
+    uint32_t skipSoftmaxVotesGemm0ToV[nbXBuf];     // guarded by skipSoftmaxXBar
+    uint32_t skipSoftmaxVotesGemm0ToGemm1[nbXBuf]; // guarded by xBar
+#endif
 
     // mem barriers
 
     CtaBarrierPair qBar;
@@ -229,6 +247,9 @@ struct alignas(128) SharedMem
     CtaBarrierPair vtBar[nbVBuf];
 #endif
     CtaBarrierPair xBar[nbXBuf];
+#if SKIP_SOFTMAX_ATTN
+    CtaBarrierPair skipSoftmaxXBar[nbXBuf]; // for V to wait for X to be ready
+#endif
 
     // used internally in the gemm0 warp group
     // @fixme: use separate arrive and wait for all usage
@@ -425,8 +446,13 @@ __device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
 #endif
 
 #if SWAP_AB
+#if SKIP_SOFTMAX_ATTN
+__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src,
+    float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip);
+#else
 __device__ RegColWiseVec computeWarpGrpColMax_sync(
     CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
+#endif
 __device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd);
 __device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax);
 __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
@@ -675,6 +701,12 @@ CUBIN_EXPORT __global__
 #endif
 #if SPEC_DEC
     SpecDecParams const specDecParams,
+#endif
+#if SKIP_SOFTMAX_ATTN
+    float const skipSoftmaxThresholdScaleFactor,
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+    uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
+#endif
 #endif
     uint32_t* __restrict__ const semaphores
     = nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
@@ -753,6 +785,10 @@ CUBIN_EXPORT __global__
     uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1;
     static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
     assert(isMultiBlockMode == (nbSubSeq > 1));
+#if SKIP_SOFTMAX_ATTN
+    bool const disableSkipForShortSeq = (cacheSeqLen < skipSoftmaxThresholdScaleFactor);
+    float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / cacheSeqLen;
+#endif
     if (idxSubSeq >= nbSubSeq)
    {
        return;
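The threshold computed above scales inversely with sequence length: a tile is later skipped when every query's local max score trails the running max by more than ln(threshold), which caps each skipped token's softmax numerator at threshold times the leading term, so roughly scaleFactor worth of relative mass across the whole sequence. Skipping is disabled exactly when the threshold would exceed 1 (seqLen < scaleFactor). A minimal numeric check; the scale factor value is borrowed from the tests later in this diff, the sequence length is made up:

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    float const skipSoftmaxThresholdScaleFactor = 1280.f; // as exercised by one of the tests below
    float const cacheSeqLen = 4096.f;                     // hypothetical

    bool const disableSkipForShortSeq = cacheSeqLen < skipSoftmaxThresholdScaleFactor;
    float const threshold = disableSkipForShortSeq ? 0.f : skipSoftmaxThresholdScaleFactor / cacheSeqLen;
    // A tile is skippable when its local max is far enough below the running max:
    //   exp(localMax - runningMax) < threshold  <=>  localMax - runningMax < ln(threshold)
    printf("threshold = %f, ln(threshold) = %f\n", threshold, std::log(threshold));
    // => with these numbers a tile is skipped when its max score trails the
    //    running max by more than ~1.16 in logit units.
    return 0;
}
```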
@@ -776,21 +812,34 @@ CUBIN_EXPORT __global__
     assert(dynamicSmemSize() >= sizeof(SharedMem));
     SharedMem& smem = *reinterpret_cast<SharedMem*>(&smemByteBuf[0]);

-    constexpr uint32_t nbBuffers = 2;
-    static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf);
-    if (wid < nbBuffers)
+    constexpr uint32_t maxNbBuffers = (SharedMem::nbXBuf > SharedMem::nbVBuf) ? SharedMem::nbXBuf : SharedMem::nbVBuf;
+    static_assert(
+        maxNbBuffers >= SharedMem::nbKBuf && maxNbBuffers >= SharedMem::nbVBuf && maxNbBuffers >= SharedMem::nbXBuf);
+    if (wid < maxNbBuffers)
     {
         if (warpElectSync())
         {
-            smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
-            smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
-#if !SWAP_AB
-            smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
+            if (wid < SharedMem::nbKBuf)
+            {
+                smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
+            }
+            if (wid < SharedMem::nbXBuf)
+            {
+#if SKIP_SOFTMAX_ATTN
+                smem.skipSoftmaxXBar[wid].initialize(gemm0NbThrds + warp_size, gemm0NbThrds + warp_size);
+                smem.vBar[wid].initialize(gemm1NbThrds + warp_size, gemm1NbThrds + warp_size);
+#else
+                smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
 #endif
-            smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
+#if !SWAP_AB
+                smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
+#endif
+                smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
+            }
         }
     }
-    else if (wid == nbBuffers)
+    else if (wid == maxNbBuffers)
     {
         if (warpElectSync())
         {
@@ -819,6 +868,10 @@ CUBIN_EXPORT __global__
         SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen};
 #endif

+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+        uint32_t localSkippedBlockCount = 0;
+#endif
+
         // QK gemm
         constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM);
         using Acc = GmmaAcc<gemm0CtaTileNbTokens, ctaNbQHeads>;
@@ -940,10 +993,39 @@ CUBIN_EXPORT __global__
                 }
             }
 #endif
+
+            uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
+            auto& xBar = smem.xBar[idxXBuf];
             // update colMax in shared mem and get a register copy
 #if SWAP_AB
+#if SKIP_SOFTMAX_ATTN
+            auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
+            skipSoftmaxXBar.consumed.arrive_and_wait();
+
+            bool const maybeSkip = !disableSkipForShortSeq && idxIter != 0;
+            RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc,
+                skipSoftmaxThreshold, &smem.skipSoftmaxVotesGemm0ToV[idxXBuf], maybeSkip);
+            bool const shouldSkipSoftmaxAttn = static_cast<bool>(smem.skipSoftmaxVotesGemm0ToV[idxXBuf]);
+            unused(skipSoftmaxXBar.produced.arrive());
+            warpGrpOnlineSoftmax(acc, colMax);
+            if (shouldSkipSoftmaxAttn)
+            {
+                xBar.consumed.arrive_and_wait();
+                if (threadIdx.x == 0)
+                {
+                    smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 1U;
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+                    localSkippedBlockCount++;
+#endif
+                }
+                asm volatile("fence.proxy.async.shared::cta;\n"); // maybe not used
+                unused(xBar.produced.arrive());
+                continue;
+            }
+#else
             RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc);
             warpGrpOnlineSoftmax(acc, colMax);
+#endif
 #else
             RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc);
             warpGrpOnlineSoftmax(acc, rowMax);
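Ordering in the added gemm0 path: the warp group first waits until the previous vote slot has been consumed (skipSoftmaxXBar.consumed), computes colMax while voting into skipSoftmaxVotesGemm0ToV, then signals skipSoftmaxXBar.produced so the V-load warp can react before the X tile even exists; gemm1 learns the verdict separately through xBar. A toy two-thread model of that handoff (std semaphores and atomics, not CUDA barriers; purely illustrative):

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <semaphore>
#include <thread>

// Toy model of the vote handoff added in this hunk (not the kernel itself).
std::atomic<uint32_t> skipVote{0};
std::binary_semaphore voteReady{0};

int main()
{
    std::thread gemm0([]
    {
        skipVote.store(1, std::memory_order_relaxed); // this tile looks skippable
        voteReady.release();                          // ~ skipSoftmaxXBar.produced.arrive()
    });
    voteReady.acquire();                              // ~ the vload warp waiting on skipSoftmaxXBar.produced
    if (skipVote.load(std::memory_order_relaxed))
        printf("vload: skip loading this V tile\n");
    gemm0.join();
    return 0;
}
```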
@@ -959,8 +1041,6 @@ CUBIN_EXPORT __global__
             // map 1 to fp8_max before conversion to fp8
             acc = acc * kE4M3_MAX;

-            uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
-            auto& xBar = smem.xBar[idxXBuf];
             // @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM.
 #if SWAP_AB
             storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc);
@@ -989,13 +1069,25 @@ CUBIN_EXPORT __global__
             storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax);
             storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum);
 #endif
+#if SKIP_SOFTMAX_ATTN
+            if (threadIdx.x == 0)
+            {
+                smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 0;
+            }
+#endif
             __syncwarp();
             // the release semantics of arrive does not work for async consumers like gmma. additional fence is
             // needed.
             asm volatile("fence.proxy.async.shared::cta;\n");
             unused(xBar.produced.arrive());
         }
+#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
+        if (threadIdx.x == 0 && skippedBlockCount != nullptr && totalBlockCount != nullptr)
+        {
+            atomicAdd(skippedBlockCount, localSkippedBlockCount);
+            atomicAdd(totalBlockCount, nbIters);
+        }
+#endif
         unused(smem.qBar.consumed.arrive());
     }
     else if (warpIdx.z == 1)
@@ -1043,216 +1135,233 @@ CUBIN_EXPORT __global__
             uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq;
             auto const idxVBuf = idxIter % SharedMem::nbVBuf;
             auto const idxXBuf = idxVBuf;
+            auto& xBar = smem.xBar[idxXBuf];
             auto& vBar = smem.vBar[idxVBuf];
-            arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
             auto const& vBuf = smem.vBuf(idxVBuf);
 #if !SWAP_AB
             CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
             auto& vtBuf = smem.vtBuf(idxVBuf);
-            vtBar.consumed.arrive_and_wait();
-            transposeVTile(warpRank, laneId(), vtBuf, vBuf);
-            vBar.consumed.arrive();
-            vtBar.produced.arrive();
 #endif
-            auto& xBar = smem.xBar[idxXBuf];
             xBar.produced.arrive_and_wait();
+#if SKIP_SOFTMAX_ATTN
+            bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf]; // guarded by xBar
+            if (shouldSkipSoftmaxAttn)
+            {
+                vBar.produced.arrive_and_wait();
+            }
+#endif
+
+#if SKIP_SOFTMAX_ATTN
+            if (!shouldSkipSoftmaxAttn) // skip XVGemm
+#endif
+            {
+                arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
+#if !SWAP_AB
+                vtBar.consumed.arrive_and_wait();
+                transposeVTile(warpRank, laneId(), vtBuf, vBuf);
+                vBar.consumed.arrive();
+                vtBar.produced.arrive();
+#endif
 #if !defined(NDEBUG) && DBG_PRINT
 #if SWAP_AB
-            if (threadIdx.x == 0)
-            {
-                printf("colMax:\n");
-                for (int i = 0; i < ctaNbQHeads; i++)
-                {
-                    printf("%f, ", smem.xColMax[idxXBuf][i]);
-                }
-                printf("\n");
-                printf("colSum:\n");
-                for (int n = 0; n < 4; n++)
-                {
-                    for (int i = 0; i < ctaNbQHeads; i++)
-                    {
-                        printf("%f, ", smem.xColSum[idxXBuf][n][i]);
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-                printf("X:\n");
-                for (int i = 0; i < ctaNbQHeads; i++)
-                {
-                    for (int j = 0; j < gemm0CtaTileNbTokens; j++)
-                    {
-                        auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
-                        auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
-                            smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
-                                i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
-                        printf("%.2f, ", float(e));
-                        if (j % 16 == 15)
-                        {
-                            printf("| ");
-                        }
-                    }
-                    printf("\n\n");
-                }
-            }
-            smem.gemm1WarpGrpBar.arrive_and_wait();
-#else
-            if (blockIdx.y == 1 && threadIdx.x == 0)
-            {
-                printf("rowMax:\n");
-                for (int i = 0; i < ctaNbQHeads; i++)
-                {
-                    printf("%f, ", smem.xRowMax[idxXBuf][i]);
-                }
-                printf("\n");
-                printf("rowSum:\n");
-                for (int i = 0; i < ctaNbQHeads; i++)
-                {
-                    printf("%f, ", smem.xRowSum[idxXBuf][i]);
-                }
-                printf("\n");
-            }
-            smem.gemm1WarpGrpBar.arrive_and_wait();
+                if (threadIdx.x == 0)
+                {
+                    printf("colMax:\n");
+                    for (int i = 0; i < ctaNbQHeads; i++)
+                    {
+                        printf("%f, ", smem.xColMax[idxXBuf][i]);
+                    }
+                    printf("\n");
+                    printf("colSum:\n");
+                    for (int n = 0; n < 4; n++)
+                    {
+                        for (int i = 0; i < ctaNbQHeads; i++)
+                        {
+                            printf("%f, ", smem.xColSum[idxXBuf][n][i]);
+                        }
+                        printf("\n");
+                    }
+                    printf("\n");
+                    printf("X:\n");
+                    for (int i = 0; i < ctaNbQHeads; i++)
+                    {
+                        for (int j = 0; j < gemm0CtaTileNbTokens; j++)
+                        {
+                            auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
+                            auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
+                                smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
+                                    i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
+                            printf("%.2f, ", float(e));
+                            if (j % 16 == 15)
+                            {
+                                printf("| ");
+                            }
+                        }
+                        printf("\n\n");
+                    }
+                }
+                smem.gemm1WarpGrpBar.arrive_and_wait();
+#else
+                if (blockIdx.y == 1 && threadIdx.x == 0)
+                {
+                    printf("rowMax:\n");
+                    for (int i = 0; i < ctaNbQHeads; i++)
+                    {
+                        printf("%f, ", smem.xRowMax[idxXBuf][i]);
+                    }
+                    printf("\n");
+                    printf("rowSum:\n");
+                    for (int i = 0; i < ctaNbQHeads; i++)
+                    {
+                        printf("%f, ", smem.xRowSum[idxXBuf][i]);
+                    }
+                    printf("\n");
+                }
+                smem.gemm1WarpGrpBar.arrive_and_wait();
 #endif
 #endif

 #if SWAP_AB
                 // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead.
                 rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf],
                     smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar);
 #else
-            rescaleGemm1AccForNewRowMax_sync(
-                warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
+                rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf],
+                    smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
 #endif
                 auto& xBuf = smem.xBuf(idxXBuf);

                 auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8,
                     gmma::getSwizzleMode<true>(SharedMem::XBuffer::Elem{}))
                                            .raw();
 #if CACHE_ELEM_ENUM == 0
                 auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8,
                     gmma::getSwizzleMode<true>(SharedMem::VBuffer::Elem{}))
                                            .raw();
 #endif
 #if SWAP_AB
                 //@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed.
 #pragma unroll
                 for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++)
                 {
 #if CACHE_ELEM_ENUM == 2
                     Vec<RegMatAFrag, gemm1NbGmmaInstM> const fragA
                         = loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK);
 #if !defined(NDEBUG) && DBG_PRINT
-                if (threadIdx.x == 0)
-                {
-                    printf("fragA:\nidxInstK == %u\n", idxInstK);
-                }
-                smem.gemm1WarpGrpBar.arrive_and_wait();
-                for (int m = 0; m < 2; m++)
-                {
-                    for (int w = 0; w < 4; w++)
-                    {
-                        if (warpRank == w)
-                        {
-                            if (laneId() == 0)
-                            {
-                                printf("    warpRank = %u\n", warpRank);
-                            }
-                            __syncwarp();
-                            for (int a = 0; a < 2; a++)
-                            {
-                                for (int b = 0; b < 8; b++)
-                                {
-                                    for (int c = 0; c < 2; c++)
-                                    {
-                                        for (int d = 0; d < 4; d++)
-                                        {
-                                            if (laneId() == b * 4 + d)
-                                            {
-                                                for (int e = 0; e < 4; e++)
-                                                {
-                                                    auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
-                                                        fragA[m](0, c)(a, 0));
-                                                    printf("%.2f, ", float(elem4[e]));
-                                                }
-                                            }
-                                            __syncwarp();
-                                        }
-                                    }
-                                    if (laneId() == 0)
-                                    {
-                                        printf("\n");
-                                    }
-                                    __syncwarp();
-                                }
-                                if (laneId() == 0 && a == 0)
-                                {
-                                    printf("----------------------\n");
-                                }
-                                __syncwarp();
-                            }
-                        }
-                        smem.gemm1WarpGrpBar.arrive_and_wait();
-                    }
-                }
+                    if (threadIdx.x == 0)
+                    {
+                        printf("fragA:\nidxInstK == %u\n", idxInstK);
+                    }
+                    smem.gemm1WarpGrpBar.arrive_and_wait();
+                    for (int m = 0; m < 2; m++)
+                    {
+                        for (int w = 0; w < 4; w++)
+                        {
+                            if (warpRank == w)
+                            {
+                                if (laneId() == 0)
+                                {
+                                    printf("    warpRank = %u\n", warpRank);
+                                }
+                                __syncwarp();
+                                for (int a = 0; a < 2; a++)
+                                {
+                                    for (int b = 0; b < 8; b++)
+                                    {
+                                        for (int c = 0; c < 2; c++)
+                                        {
+                                            for (int d = 0; d < 4; d++)
+                                            {
+                                                if (laneId() == b * 4 + d)
+                                                {
+                                                    for (int e = 0; e < 4; e++)
+                                                    {
+                                                        auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
+                                                            fragA[m](0, c)(a, 0));
+                                                        printf("%.2f, ", float(elem4[e]));
+                                                    }
+                                                }
+                                                __syncwarp();
+                                            }
+                                        }
+                                        if (laneId() == 0)
+                                        {
+                                            printf("\n");
+                                        }
+                                        __syncwarp();
+                                    }
+                                    if (laneId() == 0 && a == 0)
+                                    {
+                                        printf("----------------------\n");
+                                    }
+                                    __syncwarp();
+                                }
+                            }
+                            smem.gemm1WarpGrpBar.arrive_and_wait();
+                        }
+                    }
 #endif
 #endif
                     BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * idxInstK};
                     auto const descX = addAddr(descXBase,
                         &xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
                             0, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
 #if CACHE_ELEM_ENUM == 2
                     gmma::fence();
 #endif
 #pragma unroll
                     for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++)
                     {
 #if CACHE_ELEM_ENUM == 0
                         auto const descV
                             = addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0));
                         gmma::mma_async_shmA<MathElem, ctaNbQHeads, true, false>(
                             reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
                             descV, descX, true);
 #elif CACHE_ELEM_ENUM == 2
                         gmma::mma_async_regA<MathElem, ctaNbQHeads>(
                             reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
                             reinterpret_cast<uint32_t const(&)[2][2][1]>(fragA[idxInstM]), descX, true);
 #endif
+                    }
+                    gmma::commit_group();
+                    //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
+                    // gmma.
+                    gmma::wait_group<0>();
                 }
-            gmma::commit_group();
-            //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
-            // gmma.
-            gmma::wait_group<0>();
-        }
 #else
                 auto const descVTBase = gmma::makeMatDesc(
                     nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode<true>(SharedMem::VTBuffer{}))
                                             .raw();
                 vtBar.produced.arrive_and_wait();
                 // if (idxIter == 1 && threadIdx.x == 0) {
                 //     printf("vtBuf:\n");
                 //     dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf);
                 // }
 #pragma unroll
                 for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
-            {
-#pragma unroll
-                for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
                 {
-                    BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
-                    auto const descX = addAddr(descXBase,
-                        &xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
-                            gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
-                    auto const descVT = addAddr(
-                        descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
-                    gmma::mma_async_shmA<MathElem, headElems>(
-                        reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
-                        descVT, true);
+#pragma unroll
+                    for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
+                    {
+                        BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
+                        auto const descX = addAddr(descXBase,
+                            &xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
+                                gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
+                        auto const descVT = addAddr(
+                            descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
+                        gmma::mma_async_shmA<MathElem, headElems>(
+                            reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
+                            descVT, true);
+                    }
                 }
-            }
-            gmma::commit_group();
-            //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of gmma.
+                gmma::commit_group();
+                //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
+                // gmma.
                 gmma::wait_group<0>();
 #endif
+            }

             if (idxIter == nbIters - 1)
             {
                 // gmma::wait_group should have already synchronized threads, so this may be unnecessary.
@@ -1471,8 +1580,24 @@ CUBIN_EXPORT __global__
             tensorMap
 #endif
         };
+#if SKIP_SOFTMAX_ATTN
+        for (auto& b : smem.skipSoftmaxXBar)
+        {
+            unused(b.consumed.arrive());
+        }
+#endif
         for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++)
         {
+            uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
+            auto& vBar = smem.vBar[idxVBuf];
+#if SKIP_SOFTMAX_ATTN
+            uint32_t idxXBuf = idxIter % SharedMem::nbXBuf;
+            auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
+            skipSoftmaxXBar.produced.arrive_and_wait();
+            bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToV[idxXBuf];
+            skipSoftmaxXBar.consumed.arrive();
+#endif
+
             uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq;
             vTileLoader.loadPages(idxVTile);
 #if USE_INPUT_KV || ENABLE_PDL == 2
@@ -1506,8 +1631,20 @@ CUBIN_EXPORT __global__
             }
 #endif

-            uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
-            auto& vBar = smem.vBar[idxVBuf];
+#if SKIP_SOFTMAX_ATTN
+            if (shouldSkipSoftmaxAttn)
+            {
+                vBar.consumed.arrive_and_wait();
+                // Compared to non-skip softmax attn, we need to increase the vBar.produced count to avoid a race
+                // condition where vBar.consumed is arrived again without a wait. Without skip softmax attn, XVGemm
+                // waits on the tx_count, so its progress cannot run ahead of the vload warp. With skip softmax attn,
+                // the XVGemm WG may run ahead of the vload warp, as the previous vBar only has XVGemm WG threads and a
+                // tx_count (now = 0); it may then arrive vBar.consumed before it is arrive_and_wait-ed.
+                vBar.produced.arrive();
+                continue;
+            }
+#endif
+
             vBar.consumed.arrive_and_wait();
             if (warpElectSync())
             {
@@ -1517,6 +1654,9 @@ CUBIN_EXPORT __global__
                     vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced);
                 }
             }
+#if SKIP_SOFTMAX_ATTN
+            vBar.produced.arrive();
+#endif
             __syncwarp();
         }
     }
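A note on the relocated arrive_tx_and_wait and the extra vBar.produced arrivals: each gemm1 thread posts an equal share of the TMA transaction bytes expected on vBar.produced, and on a skipped tile that path never runs, so the vload warp appears to make up the missing arrivals itself (the barrier was initialized with warp_size extra arrivals earlier in this diff, presumably for exactly this case). The byte-share arithmetic, with made-up sizes:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical sizes, for illustration only (the real values come from
    // sizeof(SharedMem::VBuffer) and the gemm1 warp-group width).
    uint32_t const vBufferBytes = 64 * 1024;
    uint32_t const gemm1NbThrds = 128;
    // arrive_tx_and_wait splits the expected TMA byte count evenly across the
    // warp group, so vBar.produced flips only once the full V tile has landed.
    printf("tx share per gemm1 thread: %u bytes\n", vBufferBytes / gemm1NbThrds);
    return 0;
}
```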
@@ -1992,9 +2132,23 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
 #endif // SPEC_DEC

 // smemColMax is persistent across multiple iterations
+#if SKIP_SOFTMAX_ATTN
+__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax,
+    Gemm0Acc const& src, float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip)
+#else
 __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
     CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src)
+#endif
 {
+#if SKIP_SOFTMAX_ATTN
+    if (threadIdx.x == 0)
+    {
+        *smemSkipVote = maybeSkip ? 1U : 0U; // will sync before vote
+    }
+    float const lnThreshold
+        = log(skipSoftmaxThreshold); // this can be -inf, but should be safe as we only use it for comparison
+#endif
+
     auto colMax = RegColWiseVec::filled(Vec<float, 2>::filled(safeInitRowMax));
 #pragma unroll
     for (uint32_t n = 0; n < src.cols; n++)
@@ -2029,6 +2183,9 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
     }

     uint32_t const lane = laneId();
+#if SKIP_SOFTMAX_ATTN
+    auto prevOrCurrentMax = RegColWiseVec();
+#if SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
     if (lane < 4)
     {
 #pragma unroll
@@ -2037,12 +2194,43 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
 #pragma unroll
             for (uint32_t j = 0; j < 2; j++)
             {
-                atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
+                prevOrCurrentMax[n][j] = smemColMax[8 * n + 2 * lane + j];
             }
         }
     }
     warpGrpBar.arrive_and_wait();
+#endif
+#endif
+
+    if (lane < 4)
+    {
+#pragma unroll
+        for (uint32_t n = 0; n < src.cols; n++)
+        {
+#pragma unroll
+            for (uint32_t j = 0; j < 2; j++)
+            {
+#if SKIP_SOFTMAX_ATTN && !SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
+                // prevOrCurrentMax <= the actual smemColMax (after updates from all 4 warps are done), but always >=
+                // smemColMax(Prev), the smemColMax value from *before* this tile is computed.
+                // When determining whether to skip, it is safe to use prevOrCurrentMax: 1) if all 4 warps' local max <
+                // smemColMax(Prev), then prevOrCurrentMax == smemColMax(Prev) and the result is not affected; 2) if
+                // some local max > smemColMax(Prev), then prevOrCurrentMax > smemColMax(Prev) and some warps may
+                // incorrectly vote to skip, but at least one warp whose local colMax is larger will not skip, so the
+                // tile is not skipped. This saves some sync and checks, but has an issue when threshold > 1.
+                prevOrCurrentMax[n][j] = atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
+#else
+                atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
+#endif
+            }
+        }
+    }
+    warpGrpBar.arrive_and_wait();

     uint32_t const idxInQuad = lane % 4;
+#if SKIP_SOFTMAX_ATTN
+    bool localShouldSkip = true;
+#endif
+
 #pragma unroll
     for (uint32_t n = 0; n < src.cols; n++)
@@ -2050,10 +2238,21 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
 #pragma unroll
         for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
         {
+#if SKIP_SOFTMAX_ATTN
+            if (lane < 4 && 8 * n + 2 * idxInQuad + j < headGrpSize)
+            {
+                localShouldSkip &= (colMax[n][j] - prevOrCurrentMax[n][j]) < lnThreshold;
+            }
+#endif
             assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]);
             colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j];
         }
     }

+#if SKIP_SOFTMAX_ATTN
+    atomicAnd(smemSkipVote, static_cast<uint32_t>(localShouldSkip)); // this will be translated to redux and voteu
+#endif
+
     warpGrpBar.arrive_and_wait();
     return colMax;
 }
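The non-FIX variant above leans on a property of atomicMax: it returns the value observed before the update, which lands somewhere between the pre-tile max and the final max. The comment in the diff argues this is conservative; a small host simulation of a 4-warp vote (sequentialized and with made-up numbers, illustrative only):

```cpp
#include <atomic>
#include <cstdio>

// Simulation of the relaxed vote: each "warp" uses the value atomicMax saw
// *before* its own update as the skip baseline. A warp that raises the global
// max sees a stale baseline and votes "don't skip", keeping the scheme safe.
std::atomic<float> smemColMax{0.0f}; // the kernel initializes this to safeInitRowMax

float atomicMaxF(std::atomic<float>& a, float v)
{
    float old = a.load();
    while (old < v && !a.compare_exchange_weak(old, v)) {}
    return old; // value observed before the update, like CUDA atomicMax
}

int main()
{
    float const lnThreshold = -1.16f;                       // ln(skipSoftmaxThreshold)
    float const localMax[4] = {-3.0f, -2.5f, 1.5f, -4.0f};  // per-warp tile maxima (made up)
    bool vote = true;
    for (float m : localMax)
    {
        float const prevOrCurrent = atomicMaxF(smemColMax, m);
        vote &= (m - prevOrCurrent) < lnThreshold;
    }
    printf("tile skipped: %s\n", vote ? "yes" : "no"); // "no": warp 2 raised the max
    return 0;
}
```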
@@ -2199,7 +2398,7 @@ __device__ inline void storeGemm0AccToShm(
     uint32_t const idxOctInsideHalf = idxInHalf / 8;
     uint32_t const idxRowInsideOct = lane % 8;
     uint32_t const warpBaseC = 16 * warpRank;
-    auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair<uint32_t, uint32_t>
+    auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> mha::pair<uint32_t, uint32_t>
     {
         uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols;
         uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols;
@@ -3231,6 +3430,24 @@ __device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst,
 }

 #ifndef GENERATE_CUBIN
+uint32_t computeNbSubSeqPerSeqHopperF8MHA(
+    cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
+{
+    auto const env = std::getenv("XQA_NB_SUB_SEQ");
+    if (env != nullptr)
+    {
+        int32_t const val = std::stoi(env);
+        if (val > 0)
+        {
+            return val;
+        }
+    }
+    float const factor = 0.25f;
+    return mha::min<uint32_t>(
+        mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
+        divUp(maxSeqLen, gemm0CtaTileNbTokens));
+}
+
 void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
 #if SLIDING_WINDOW
     uint32_t slidingWinSize,
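The extracted helper keeps the old heuristic: target roughly 3 CTAs per SM, spread over batchSize x nbKHeads work items, damped by factor 0.25, and never more sub-sequences than there are gemm0 tiles. A worked example with hypothetical device numbers:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical inputs (an H100-class part has 132 SMs; the rest are made up).
    uint32_t const multiProcessorCount = 132;
    uint32_t const batchSize = 1, nbKHeads = 8;
    uint32_t const maxSeqLen = 8192, gemm0CtaTileNbTokens = 64;
    float const factor = 0.25f;

    // Note the integer division happens before the float scaling, as in the
    // helper: 132 * 3 / (1 * 8) = 49, then 49 * 0.25 = 12.25, rounded to 12.
    uint32_t const byOccupancy = std::max<uint32_t>(
        1U, (uint32_t) std::round(multiProcessorCount * 3 / (batchSize * nbKHeads) * factor));
    uint32_t const byTiles = (maxSeqLen + gemm0CtaTileNbTokens - 1) / gemm0CtaTileNbTokens; // divUp
    printf("nbSubSeqPerSeq = %u\n", std::min(byOccupancy, byTiles)); // min(12, 128) = 12
    return 0;
}
```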
@@ -3268,6 +3485,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
     // int8/fp8 KV cache.
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
+#endif
+#if SKIP_SOFTMAX_ATTN
+    float const skipSoftmaxThresholdScaleFactor,
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+    uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
+#endif
 #endif
     uint32_t* semaphores, void* scratch, cudaStream_t stream)
 {
@@ -3286,22 +3509,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
     uint32_t const nbVHeads = nbKHeads;
     uint32_t const nbQHeads = nbKHeads * headGrpSize;
     uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
-    uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
-    {
-        auto const env = std::getenv("XQA_NB_SUB_SEQ");
-        if (env != nullptr)
-        {
-            int32_t const val = std::stoi(env);
-            if (val > 0)
-            {
-                return val;
-            }
-        }
-        float const factor = 0.25f;
-        return mha::min<uint32_t>(
-            mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
-            divUp(maxSeqLen, gemm0CtaTileNbTokens));
-    }();
+    uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqHopperF8MHA(prop, batchSize, nbKHeads, maxSeqLen);
 #if SPEC_DEC
     uint32_t const qSeqLen = specDecParams.qSeqLen;
 #else
@@ -3371,6 +3579,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
 #endif
 #if SPEC_DEC
         specDecParams,
+#endif
+#if SKIP_SOFTMAX_ATTN
+        skipSoftmaxThresholdScaleFactor,
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+        skippedBlockCount, totalBlockCount,
+#endif
 #endif
         semaphores, scratch);
 #else
@@ -1272,6 +1272,19 @@ using is_void = is_same<remove_cv_t<T>, void>;
 template <typename T>
 inline constexpr bool is_void_v = is_void<T>::value;
 #endif

+#ifndef GENERATE_CUBIN
+template <typename T1, typename T2>
+using pair = std::pair<T1, T2>;
+#else
+template <typename T1, typename T2>
+struct pair
+{
+    T1 first;
+    T2 second;
+};
+#endif
+
 } // namespace mha

 #if GENERATE_CUBIN
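The alias exists because the GENERATE_CUBIN build avoids pulling in host standard-library headers such as &lt;utility&gt;; the hand-rolled struct covers the simple aggregate uses in this diff (see the toAccCoords return-type change above), and outside that build the alias keeps call sites on std::pair. A standalone sketch of the cubin-side variant:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for the GENERATE_CUBIN variant added above (assumption: same layout).
namespace mha
{
template <typename T1, typename T2>
struct pair
{
    T1 first;
    T2 second;
};
} // namespace mha

int main()
{
    // Call sites such as toAccCoords can return this without touching <utility>.
    mha::pair<uint32_t, uint32_t> const coords{1U, 2U};
    printf("accR=%u accC=%u\n", coords.first, coords.second);
    return 0;
}
```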
@@ -50,7 +50,8 @@ using Vector = Matrix<Type, Size, 1>;
 template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
 Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
     CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
-    float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
+    float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
+    uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
 {
     uint32_t const nbTiles = divUp(seqLen, tileSize);
     auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();
@@ -61,6 +62,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
     float const qkScale = qScale * kvScale / sqrtf(validElemsPerHead);
     uint32_t const seqBeg = (seqLen < slidingWinSize ? 0 : seqLen - slidingWinSize);
     uint32_t const idxTileBeg = seqBeg / tileSize;

+    uint32_t const nbSubSeq = (multiBlockNum > 0 && nbTiles >= 2) ? mha::min(nbTiles, multiBlockNum) : 1;
+    std::vector<Eigen::Vector<float, headGrpSize>> skipRowMaxs(nbSubSeq);
+    for (uint32_t i = 0; i < nbSubSeq; i++)
+    {
+        skipRowMaxs[i].fill(-INFINITY);
+    }
+    bool const disableSkipForShortSeq = (seqLen < skipSoftmaxThresholdScaleFactor);
+    float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / seqLen;
+
     for (uint32_t idxTile = idxTileBeg; idxTile < nbTiles; idxTile++)
     {
         Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> gemm0Acc;
@@ -88,7 +99,22 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
             }
         }

-        Eigen::Vector<float, headGrpSize> const tileRowMax = gemm0Acc.rowwise().maxCoeff().cwiseMax(rowMax).eval();
+        Eigen::Vector<float, headGrpSize> const localRowMax = gemm0Acc.rowwise().maxCoeff().eval();
+        Eigen::Vector<float, headGrpSize> const tileRowMax = localRowMax.cwiseMax(rowMax).eval();
+        auto const prevSkipRowMax = skipRowMaxs[idxTile % nbSubSeq];
+        skipRowMaxs[idxTile % nbSubSeq] = localRowMax.cwiseMax(skipRowMaxs[idxTile % nbSubSeq]).eval();
+
+        if (!disableSkipForShortSeq && skipSoftmaxThreshold > 0)
+        {
+            *totalBlockCount += 1;
+            auto const skipSoftmaxMask = ((localRowMax - prevSkipRowMax).array() < std::log(skipSoftmaxThreshold));
+            bool const skipBlock = skipSoftmaxMask.all() && ((idxTile - idxTileBeg) >= nbSubSeq);
+            if (skipBlock)
+            {
+                *skippedBlockCount += 1;
+                continue;
+            }
+        }
+
         Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> tileX
             = (gemm0Acc.colwise() - tileRowMax).array().exp().eval();
@@ -138,7 +164,8 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
     template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>                                    \
     refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q,                                        \
         CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen,         \
-        float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
+        float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks,                     \
+        float skipSoftmaxThreshold, uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)

 INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
 INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);
|
|||||||
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
||||||
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
||||||
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
||||||
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
|
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
|
||||||
|
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum);
|
||||||
|
|
||||||
template <typename MathElem, bool isPaged, bool useBeamSearch>
|
template <typename MathElem, bool isPaged, bool useBeamSearch>
|
||||||
#if SPEC_DEC
|
#if SPEC_DEC
|
||||||
|
|||||||
@@ -150,7 +150,8 @@ template <uint32_t nbKHeads>
 #endif
 #endif
 void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false,
-    bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
+    bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30,
+    float skipSoftmaxThresholdScaleFactor = 0.0f)
 {
 #if IS_MLA
     if (nbKHeads != 1)
@@ -224,6 +225,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
         seqLen = (16U << 20) / gmemCacheHeadBytes; // 32MB per K+V head.
     }
     ctxLen = std::min(ctxLen, seqLen);
+    uint32_t skippedBlockCount = 0;
+    uint32_t totalBlockCount = 0;
+    if (skipSoftmaxThresholdScaleFactor > 0)
+    {
+        assert(useQGMMA);
+    }
     float const kScale = cacheElemSize == 2 ? 1.f : 1 / 4.f;
     float const vScale = kScale;
     float const qScale = 1.f;
@@ -329,6 +336,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
     auto const rcpOutScale = ManagedMemBuf<float>(1);
     auto const seqLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
     auto const ctxLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
+#if SKIP_SOFTMAX_ATTN
+#ifdef SKIP_SOFTMAX_ATTN_BLOCK_STATS
+    auto const kernelSkippedBlockCount = ManagedMemBuf<uint32_t>(1);
+    auto const kernelTotalBlockCount = ManagedMemBuf<uint32_t>(1);
+    kernelSkippedBlockCount[0] = 0;
+    kernelTotalBlockCount[0] = 0;
+#endif
+#else
+    EXPECT_EQ(skipSoftmaxThresholdScaleFactor, 0.0f)
+        << "Got non-zero skipSoftmaxThresholdScaleFactor while SKIP_SOFTMAX_ATTN is not enabled.";
+#endif
 #if USE_PAGED_KV_CACHE
     auto const pageListBuf = ManagedMemBuf<std::byte>(pageListBytes);
 #if PAGED_KV_CACHE_LAYOUT == 1
@@ -726,6 +744,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
             maxSeqLen, &seqLenList[0][0], batchSize, kvCacheScale.get(), semaphores.get(), scratch, stream);
     };
 #else
+    auto multiBlockNum = [&]()
+    {
+        auto const calcFunc = useQGMMA ? &computeNbSubSeqPerSeqHopperF8MHA : &computeNbSubSeqPerSeqMHA;
+        return calcFunc(prop, batchSize, nbKHeads, maxSeqLen);
+    }();
     auto runKernel = [&]()
     {
         auto const launchFunc = useQGMMA ? &launchHopperF8MHA : &launchMHA;
@@ -776,6 +799,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
             batchSize, kvCacheScale.get(),
 #if SPEC_DEC
             specDecParams,
+#endif
+#if SKIP_SOFTMAX_ATTN
+            skipSoftmaxThresholdScaleFactor,
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+            kernelSkippedBlockCount.get(), kernelTotalBlockCount.get(),
+#endif
 #endif
             semaphores.get(), scratch, stream);
         checkCuda(cudaGetLastError());
@@ -813,6 +842,10 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
     checkCuda(cudaEventRecord(toc, stream));
     prefetchToDevice(cudaCpuDeviceId);
     checkCuda(cudaStreamSynchronize(stream));
+#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
+    kernelSkippedBlockCount[0] /= nbIters;
+    kernelTotalBlockCount[0] /= nbIters;
+#endif
     if (testPerf)
     {
         float ms;
@@ -849,6 +882,15 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
             = totalNbCacheLoadBytes + inputBytes + outputBytes; // we ignore page indices and beam search indices.
         float const dramSolTime = totalTraffic / bandwidth * 1E3f;
         float const dramSolRatio = dramSolTime / ms;
+#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
+        size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes
+            * (nbKHeads + nbVHeads * (1 - 1.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]))
+            * nbLoadedCacheTokens;
+        float const totalTrafficWithSkip
+            = totalNbCacheLoadWithSkip + inputBytes + outputBytes; // we ignore page indices and beam search indices.
+        float const dramSolTimeWithSkip = totalTrafficWithSkip / bandwidth * 1E3f;
+        float const dramSolRatioWithSkip = dramSolTimeWithSkip / ms;
+#endif
         if (verbose)
         {
             printf("done\n");
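The adjusted speed-of-light model above keeps K traffic intact (K is always read, since gemm0 feeds the skip decision) and discounts V traffic by the measured skip ratio. With toy numbers:

```cpp
#include <cstdio>

int main()
{
    // Toy numbers (not measured), following the formula in the hunk above.
    double const gmemCacheHeadBytes = 576, nbKHeads = 2, nbVHeads = 2;
    double const nbLoadedCacheTokens = 4096;
    double const skippedBlocks = 30, totalBlocks = 64; // per-run averages
    double const vKept = 1.0 - skippedBlocks / totalBlocks;
    // K is always read; only V loads are elided for skipped tiles.
    double const bytesWithSkip = gmemCacheHeadBytes * (nbKHeads + nbVHeads * vKept) * nbLoadedCacheTokens;
    printf("cache bytes with skip: %.0f\n", bytesWithSkip);
    return 0;
}
```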
@@ -863,7 +905,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
         }
         float const tops = headGrpSize * qSeqLen * float(seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
             * nbKHeads * batchSize / (ms * 1E-3F) * 1E-12F;
+#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
+        printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
+            kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
+        printf("dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n", dramSolRatioWithSkip * 100, ms, tops);
+#else
         printf("dramSolRatio: %f%% (%f ms, TOPS = %f)\n", dramSolRatio * 100, ms, tops);
+#endif
     }
     if (refCheck)
     {
@@ -1084,8 +1132,8 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
                 if (useQGMMA)
                 {
                     refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
-                        vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize,
-                        refAttentionSinks);
+                        vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks,
+                        skipSoftmaxThresholdScaleFactor, &skippedBlockCount, &totalBlockCount, multiBlockNum);
                     // refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
                     //     vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
                 }
@@ -1132,6 +1180,14 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
 #endif
         }
     }
+#if SKIP_SOFTMAX_ATTN
+    printf("host skippedBlockCount: %d/%d (%.2f%%)\n", skippedBlockCount, totalBlockCount,
+        totalBlockCount == 0 ? 0.0f : 100.0f * skippedBlockCount / totalBlockCount);
+#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
+    printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
+        kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
+#endif
+#endif
     if (saveData)
     {
         fout_refOutput.close();
@@ -1253,6 +1309,14 @@ TEST(RefCheck, llama_V2_70b)
 #if SLIDING_WINDOW
     runTest<2>(2, 4096, false, true, false, false, false, ~0, 256);
     runTest<2>(2, 400, false, true, false, false, false, ~0U, 256);
+#endif
+#if SKIP_SOFTMAX_ATTN
+    runTest<1>(32, 2048, false, true, false, false, false, ~0U, 1U << 30, 0.f);
+    runTest<4>(32, 1538, false, true, false, false, false, ~0U, 1U << 30, 1280.f);
+    runTest<2>(32, 4096, false, true, false, false, false, ~0U, 1U << 30, 125.f);
+    runTest<4>(32, 300, false, true, false, false, false, ~0U, 1U << 30, 80.f);
+    runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 501.0f);
+    runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 500.f);
 #endif
     runTest<8>(120, 367, false, true);
     runTest<8>(1792, 2048, false, true);
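The last two cases above probe the disable boundary: skipping is turned off whenever seqLen < skipSoftmaxThresholdScaleFactor, so with seqLen = 500 a scale factor of 501 disables it while 500 keeps it enabled with a threshold of exactly 1. A trivial check:

```cpp
#include <cstdio>

int main()
{
    unsigned const seqLen = 500;
    float const scales[] = {501.0f, 500.0f};
    for (float scale : scales)
    {
        bool const disabled = seqLen < scale; // disableSkipForShortSeq
        printf("scaleFactor=%.0f -> skipping %s\n", scale, disabled ? "disabled" : "enabled (threshold == 1)");
    }
    return 0;
}
```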
@@ -157,6 +157,11 @@ set(UCX_WRAPPER_TARGET tensorrt_llm_ucx_wrapper)

 if(NIXL_ROOT)
   set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
+  set(TRANSFER_AGENT_BINDING_TARGET tensorrt_llm_transfer_agent_binding)
+endif()
+
+if(MOONCAKE_ROOT)
+  set(MOONCAKE_WRAPPER_TARGET tensorrt_llm_mooncake_wrapper)
 endif()

 add_subdirectory(executor)
@@ -272,6 +277,11 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
   add_dependencies(${SHARED_TARGET} ${NIXL_WRAPPER_TARGET})
 endif()

+if(TARGET ${MOONCAKE_WRAPPER_TARGET})
+  target_link_libraries(${MOONCAKE_WRAPPER_TARGET} INTERFACE ${SHARED_TARGET})
+  add_dependencies(${SHARED_TARGET} ${MOONCAKE_WRAPPER_TARGET})
+endif()
+
 if(NOT WIN32)
   # Load libraries at $PREFIX/lib from
   # $PREFIX/lib/python3.12/site-packages/tensorrt_llm/libs
@@ -283,13 +293,7 @@ if(BUILD_PYT)
   add_subdirectory(thop)
 endif()

-if(BINDING_TYPE STREQUAL "pybind")
-  add_subdirectory(pybind)
-endif()
-
-if(BINDING_TYPE STREQUAL "nanobind")
-  add_subdirectory(nanobind)
-endif()
+add_subdirectory(nanobind)

 if(BUILD_DEEP_EP)
   add_subdirectory(deep_ep)
|
|||||||
@@ -21,6 +21,7 @@ set(TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(SRCS
   allocateKvCache.cpp
   assignReqSeqSlots.cpp
+  baseTransBuffer.cpp
   cacheFormatter.cpp
   mlaCacheFormatter.cpp
   cacheTransceiver.cpp

@@ -36,6 +37,8 @@ set(SRCS
   kvCacheManager.cpp
   kvCacheEventManager.cpp
   kvCacheTransferManager.cpp
+  kvCacheManagerV2Utils.cpp
+  kvCacheManagerV2Utils.cu
   llmRequest.cpp
   logitsPostProcessor.cpp
   loraBuffers.cpp
new file: cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp (285 lines)
@@ -0,0 +1,285 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "baseTransBuffer.h"
#include "cacheTransBuffer.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"

#include <mutex>

namespace tensorrt_llm::batch_manager
{

BaseTransBufferManager::BaseTransBufferManager(
    size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens)
    : mDataType{dataType}
    , mBufferManager{std::make_shared<runtime::CudaStream>()}
    , mMaxNumTokens{maxNumTokens}
{
    mTransferBufferSize = transferBufferSize;
    mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
    mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
    mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
        && kv_cache_manager::FabricMemory::supportFbaricMemory();
    if (mUseFabricMemory)
    {
        mTransferBufferSize = kv_cache_manager::FabricMemory::getAlignedSize(mTransferBufferSize);
    }
    mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);

    TLLM_LOG_INFO(
        "BaseTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
        "mSendBufferCount:%ld, mTransferBufferSize:%ld, mPreAllocBufferSize:%ld, mOnlyUseDynamicBuffer:%d, "
        "mUseFabricMemory:%d, mDataType:%d",
        maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
        mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, static_cast<int>(mDataType));

    allocateBuffer();
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForSend()
{
    return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv()
{
    return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateSendBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse)
{
    return getOrAllocateBuffers(
        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceSendResource);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateRecvBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse)
{
    return getOrAllocateBuffers(
        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceRecvResource);
}

runtime::ITensor::SharedPtr BaseTransBufferManager::getSendBuffer(std::optional<int> bufferId)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mSendBufferCount);
        return mConcurrenceSendResource.mBuffers[bufferId.value()];
    }
    return nullptr;
}

runtime::ITensor::SharedPtr BaseTransBufferManager::getRecvBuffer(std::optional<int> bufferId)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mRecvBufferCount);
        // TLLM_CHECK(mConcurrenceRecvResource.mBufferIndexFlag[bufferId.value()] == 1);
        return mConcurrenceRecvResource.mBuffers[bufferId.value()];
    }
    return nullptr;
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    TLLM_CHECK(requestedNumberOfElements.size() >= static_cast<size_t>(targetNum));
    std::vector<runtime::ITensor::SharedPtr> retSplitCaches;

    size_t bufferCoverTargetNum = 0;

    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < concurrenceResource.mBuffers.size());
        TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1);
        size_t preBufferEleSize = 0;
        for (int i = 0; i < targetNum; i++)
        {
            // Strict checking.
            if (preBufferEleSize + requestedNumberOfElements[i] <= mNumberOfElements)
            {
                auto slice = runtime::ITensor::slice(
                    concurrenceResource.mBuffers[bufferId.value()], preBufferEleSize, requestedNumberOfElements[i]);
                preBufferEleSize += requestedNumberOfElements[i];
                bufferCoverTargetNum++;
                retSplitCaches.push_back(std::move(slice));
            }
            else
            {
                retSplitCaches.push_back(bufferManagerToUse.gpu(
                    runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
            }
        }
        TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum);
        if (bufferCoverTargetNum < static_cast<size_t>(targetNum))
        {
            TLLM_LOG_WARNING(
                "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic "
                "buffer which will fail with NIXL backend. It is recommended to set "
                "cacheTransceiverConfig.MaxTokensInBuffer (cache_transceiver_config.max_tokens_in_buffer in config "
                "YAML file) to a value greater than the maximum ISL of the processed requests. Otherwise, performance "
                "may be degraded or transfer may fail. requestedNumberOfElements.size():%ld, "
                "mNumberOfElements:%ld, requestedNumberOfElements[0]:%ld",
                bufferCoverTargetNum, targetNum, requestedNumberOfElements.size(), mNumberOfElements,
                requestedNumberOfElements[0]);
        }
    }
    else
    {
        for (int i = 0; i < targetNum; i++)
        {
            retSplitCaches.push_back(bufferManagerToUse.gpu(
                runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
        }
        bufferCoverTargetNum = targetNum;
    }

    return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::allocateBuffer()
{
    if (mOnlyUseDynamicBuffer)
    {
        return;
    }
    mNumberOfElements = mTransferBufferSize / common::getDTypeSize(mDataType);
    mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
    mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);
    if (mUseFabricMemory)
    {
        mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mFabricMemory.emplace_back(std::make_unique<kv_cache_manager::FabricMemory>(mTransferBufferSize));
            mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mFabricMemory.emplace_back(std::make_unique<kv_cache_manager::FabricMemory>(mTransferBufferSize));
            mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
        }
    }
    else if (common::getEnvKVCacheTransferUseAsyncBuffer())
    {
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mConcurrenceSendResource.mBuffers[i]
                = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mConcurrenceRecvResource.mBuffers[i]
                = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        mBufferManager.getStream().synchronize();
    }
    else
    {
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
    }
}

std::optional<int> BaseTransBufferManager::assignBufferIndex(
    ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer)
{
    if (onlyUseDynamicBuffer)
    {
        return std::nullopt;
    }
    std::unique_lock lk(resource.mBuffersMutex);
    resource.mBuffersCV.wait(
        lk, [&resource, bufferCount]() { return static_cast<size_t>(resource.mConcurrence) < bufferCount; });
    int bufferId = -1;
    for (size_t i = 0; i < bufferCount; i++)
    {
        if (resource.mBufferIndexFlag[i] == 0)
        {
            bufferId = i;
            resource.mBufferIndexFlag[bufferId] = 1;
            resource.mConcurrence++;
            break;
        }
    }
    TLLM_CHECK_WITH_INFO(bufferId >= 0 && static_cast<size_t>(bufferId) < bufferCount,
        " assignBufferIndex: Buffer index already assigned");

    return bufferId;
}

void BaseTransBufferManager::freeBufferIndex(
    ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer)
{
    if (onlyUseDynamicBuffer)
    {
        return;
    }
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
        {
            std::scoped_lock lk(resource.mBuffersMutex);
            resource.mBufferIndexFlag[bufferId.value()] = 0;
        }
        resource.mConcurrence--;
        resource.mBuffersCV.notify_one();
    }
}

size_t BaseTransBufferManager::getRecvBufferCount()
{
    return mRecvBufferCount;
}

size_t BaseTransBufferManager::getSendBufferCount()
{
    return mSendBufferCount;
}

} // namespace tensorrt_llm::batch_manager
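For orientation, a minimal usage sketch (not part of the commit) of the pre-allocated buffer protocol the new base class owns; `manager` stands in for any concrete subclass instance, and the element counts are made up:

// Sketch only: drives assign -> slice -> free on the send side.
#include "tensorrt_llm/batch_manager/baseTransBuffer.h"

#include <vector>

void sendSideSketch(tensorrt_llm::batch_manager::BaseTransBufferManager& manager,
    tensorrt_llm::runtime::BufferManager const& bufMgr)
{
    // Blocks on the condition variable until a pre-allocated slot is free;
    // returns std::nullopt when the manager runs in dynamic-only mode.
    auto bufferId = manager.assignBufferIndexForSend();

    // Ask for two slices out of the pre-allocated buffer; anything that does
    // not fit falls back to a dynamic GPU allocation via bufMgr.
    std::vector<size_t> elementCounts{4096, 4096}; // illustrative sizes
    auto [buffers, coveredTargets, dynamicOnly]
        = manager.getOrAllocateSendBuffers(bufferId, /*targetNum=*/2, elementCounts, bufMgr);
    if (coveredTargets < buffers.size())
    {
        // Some targets spilled into dynamic buffers (see the warning in getOrAllocateBuffers).
    }
    (void) dynamicOnly;

    // ... copy KV cache blocks into `buffers` and hand them to the transport ...

    // Release the slot so other concurrent transfers can reuse it.
    manager.freeBufferIndexForSend(bufferId);
}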
new file: cpp/tensorrt_llm/batch_manager/baseTransBuffer.h (144 lines)
@@ -0,0 +1,144 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <optional>
#include <tuple>
#include <unordered_map>
#include <vector>

namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class FabricMemory;
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

namespace tensorrt_llm::batch_manager
{

/// @brief Base class for cache transfer buffer management.
/// Handles buffer pool allocation, index assignment, and slicing.
/// Derived classes provide cache-specific size calculations.
class BaseTransBufferManager
{
public:
    virtual ~BaseTransBufferManager() = default;

    /// @brief Assign a buffer index for sending.
    /// @return Assigned buffer index, or nullopt if using dynamic buffers.
    std::optional<int> assignBufferIndexForSend();

    /// @brief Free a buffer index used for sending.
    /// @param bufferId The buffer index to free.
    void freeBufferIndexForSend(std::optional<int> bufferId);

    /// @brief Assign a buffer index for receiving.
    /// @return Assigned buffer index, or nullopt if using dynamic buffers.
    std::optional<int> assignBufferIndexForRecv();

    /// @brief Free a buffer index used for receiving.
    /// @param bufferId The buffer index to free.
    void freeBufferIndexForRecv(std::optional<int> bufferId);

    /// @brief Get or allocate send buffers for cache transfer.
    /// @param bufferId The assigned buffer ID.
    /// @param targetNum Number of target sequences.
    /// @param requestedNumberOfElements Sizes requested for each target.
    /// @param bufferManagerToUse Buffer manager for dynamic allocation.
    /// @return Tuple of (buffers, covered target count, is dynamic only).
    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateSendBuffers(
        std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
        runtime::BufferManager const& bufferManagerToUse);

    /// @brief Get or allocate receive buffers for cache transfer.
    /// @param bufferId The assigned buffer ID.
    /// @param targetNum Number of target sequences.
    /// @param requestedNumberOfElements Sizes requested for each target.
    /// @param bufferManagerToUse Buffer manager for dynamic allocation.
    /// @return Tuple of (buffers, covered target count, is dynamic only).
    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateRecvBuffers(
        std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
        runtime::BufferManager const& bufferManagerToUse);

    /// @brief Get the send buffer for a given buffer ID.
    runtime::ITensor::SharedPtr getSendBuffer(std::optional<int> bufferId);

    /// @brief Get the receive buffer for a given buffer ID.
    runtime::ITensor::SharedPtr getRecvBuffer(std::optional<int> bufferId);

    /// @brief Get the number of receive buffers.
    size_t getRecvBufferCount();

    /// @brief Get the number of send buffers.
    size_t getSendBufferCount();

    /// @brief Get the maximum number of tokens configured.
    std::optional<size_t> getMaxNumTokens()
    {
        return mMaxNumTokens;
    }

protected:
    /// @brief Constructor - derived classes call this after computing buffer sizes.
    /// @param transferBufferSize Size of each transfer buffer in bytes.
    /// @param dataType Data type for the buffers.
    /// @param maxNumTokens Optional max tokens for sizing.
    BaseTransBufferManager(
        size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens = std::nullopt);

    struct ConcurrenceResource
    {
        std::unordered_map<int, runtime::ITensor::SharedPtr> mBuffers;
        std::vector<int> mBufferIndexFlag;
        std::mutex mBuffersMutex;
        std::condition_variable mBuffersCV;
        std::atomic<int> mConcurrence{0};
    };

    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateBuffers(std::optional<int> bufferId,
        int targetNum, std::vector<size_t> const& requestedNumberOfElements,
        runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource);

    void allocateBuffer();
    std::optional<int> assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer);
    void freeBufferIndex(
        ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer);

    size_t mPreAllocBufferSize;
    size_t mRecvBufferCount;
    size_t mSendBufferCount;
    size_t mTransferBufferSize;
    bool mOnlyUseDynamicBuffer;
    bool mUseFabricMemory;
    size_t mNumberOfElements;
    nvinfer1::DataType mDataType;
    ConcurrenceResource mConcurrenceSendResource;
    ConcurrenceResource mConcurrenceRecvResource;
    runtime::BufferManager mBufferManager;
    std::vector<std::unique_ptr<kv_cache_manager::FabricMemory>> mFabricMemory;
    std::optional<size_t> mMaxNumTokens;
};

} // namespace tensorrt_llm::batch_manager
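To make the extension point concrete: a hypothetical subclass sketch (not from the commit) that computes a cache-specific size and element type up front, then forwards them to the protected constructor; this mirrors the pattern CacheTransBufferManager adopts later in this diff.

#include "tensorrt_llm/batch_manager/baseTransBuffer.h"

class ToyTransBufferManager : public tensorrt_llm::batch_manager::BaseTransBufferManager
{
public:
    // bytesPerToken and maxTokens stand in for whatever the concrete cache layout dictates.
    ToyTransBufferManager(size_t bytesPerToken, size_t maxTokens)
        : BaseTransBufferManager(bytesPerToken * maxTokens, // transferBufferSize in bytes
            nvinfer1::DataType::kHALF,                      // element type of the pool
            maxTokens)                                      // optional token bound for sizing
    {
    }
};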
@@ -50,7 +50,8 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
 
     // Note: When recv side has CP, the requested seqLen is lesser than seqLen on the sender side as seqLen is
     // distributed among CP ranks. So, we transfer all blocks from send side.
-    if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || lastBlockKey.uniqueTokens.size() == 0 || recvSideHasCP)
+    if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || !cacheManager->isEnablePartialReuse()
+        || lastBlockKey.uniqueTokens.size() == 0 || recvSideHasCP)
     {
         // disable reuse path, and vwsa don't support reuse.
         bool needSendAllForWindow = common::getEnvKVCacheTransferAllBlocksForWindow();

@@ -87,13 +88,13 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
     return BlockRange::fromReuseTree(*cacheManager, lastBlockKey, indexFromEnd);
 }
 
-BlockRange getBlockRangeForReceiving(
-    BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse, bool recvSideHasCP)
+BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
+    bool srcEnableBlockReuse, bool srcEnablePartialReuse, bool recvSideHasCP)
 {
     // Note: When recv side has CP, we request all blocks from send side right now.
     auto poolNum = cacheManager->getBlockManager().getNumPools(
         /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
-    if (poolNum == 1 && srcEnableBlockReuse && !recvSideHasCP)
+    if (poolNum == 1 && srcEnableBlockReuse && srcEnablePartialReuse && !recvSideHasCP)
     {
         // Build from all block ids, then slice off the reused blocks so we only transfer newly allocated ones.
         auto windowSize = cacheManager->getBlockManager().getWindowSizesMetadata().begin()->first;

@@ -154,7 +155,8 @@ bool CacheFormatter::needSendCache(
         return true;
     }
 
-    int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
+    int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
+    int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
     int selfTpRankInDpGroup = selfTpRank;
     if (selfConfig.getParallelConfig().mEnableAttentionDP)
     {

@@ -554,7 +556,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
     auto const& destConfig = session.getOtherState().getCacheState().value();
     auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
     auto& bufferManager = session.getBufferManager();
-    auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse());
+    auto blockRange = getBlockRangeForReceiving(
+        mCacheManager, llmRequest, destConfig.getEnableBlockReuse(), destConfig.getEnablePartialReuse());
 
     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
 
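The reworked selfTpRank expression above assumes the context-parallel rank varies fastest within each tensor-parallel rank. A tiny self-contained check of that arithmetic, with illustrative sizes only:

#include <cassert>

int main()
{
    int const tp = 4, cp = 2;
    for (int tpRank = 0; tpRank < tp; ++tpRank)
    {
        for (int cpRank = 0; cpRank < cp; ++cpRank)
        {
            int selfIdx = tpRank * cp + cpRank;          // assumed layout within one PP stage
            int selfTpRank = (selfIdx % (tp * cp)) / cp; // expression from the hunk above
            assert(selfTpRank == tpRank);                // recovers tpRank regardless of CP
        }
    }
    return 0;
}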
@@ -53,7 +53,7 @@ using CacheTransBufferManager = kv_cache_manager::CacheTransBufferManager;
 using BlockRange = kv_cache_manager::BlockRange;
 
 BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
-    bool srcEnableBlockReuse, bool recvSideHasCP = false);
+    bool srcEnableBlockReuse, bool srcEnablePartialReuse, bool recvSideHasCP = false);
 
 // Used to support the cache transmission with different layouts and different protocols.
 class BaseCacheFormatter
@@ -20,12 +20,17 @@
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/common/opUtils.h"
 #include "tensorrt_llm/executor/executor.h"
 
 #include <NvInferRuntimeBase.h>
 #include <mutex>
 
 namespace tensorrt_llm::batch_manager::kv_cache_manager
 {
 
+// ============================================================================
+// FabricMemory Implementation
+// ============================================================================
+
 class FabricMemory::Impl
 {
 public:

@@ -182,45 +187,46 @@ bool FabricMemory::supportFbaricMemory()
 #endif
 }
 
-CacheTransBufferManager::CacheTransBufferManager(
+// ============================================================================
+// CacheTransBufferManager Implementation
+// ============================================================================
+
+size_t CacheTransBufferManager::computeTransferBufferSize(
     KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens, bool transferIndexerKCache)
-    : mCacheManager{cacheManager}
-    , mBufferManager{std::make_shared<runtime::CudaStream>()}
-    , mMaxNumTokens{maxNumTokens}
 {
-    // TODO: FP4 dataSize
-    TLLM_CHECK(mCacheManager);
+    nvinfer1::DataType dataType;
     if (transferIndexerKCache)
     {
-        mDataType = mCacheManager->getIndexerKCachePool()->getDataType();
+        dataType = cacheManager->getIndexerKCachePool()->getDataType();
     }
     else
     {
-        mDataType = mCacheManager->getPrimaryPool(0)->getDataType();
+        dataType = cacheManager->getPrimaryPool(0)->getDataType();
    }
 
-    auto tokensPerBlock = mCacheManager->getBlockManager().getTokensPerBlock();
+    auto tokensPerBlock = cacheManager->getBlockManager().getTokensPerBlock();
     size_t bufferSizeFromMaxNumToken = 0;
 
     if (maxNumTokens.has_value())
     {
         TLLM_CHECK(maxNumTokens.value() % tokensPerBlock == 0);
-        auto dataSize = common::getDTypeSize(mDataType);
+        auto dataSize = common::getDTypeSize(dataType);
         SizeType32 kvCacheByteSizePerTokenPerLayer = 0;
         if (transferIndexerKCache)
         {
             kvCacheByteSizePerTokenPerLayer
-                = mCacheManager->getIndexerKCachePool()->getDimension<-1>() * dataSize / tokensPerBlock;
+                = cacheManager->getIndexerKCachePool()->getDimension<-1>() * dataSize / tokensPerBlock;
         }
         else
         {
-            auto primaryPool = mCacheManager->getPrimaryPool(0);
+            auto primaryPool = cacheManager->getPrimaryPool(0);
             kvCacheByteSizePerTokenPerLayer
                 = primaryPool->getDimension<-1>() * primaryPool->getDimension<2>() * dataSize / tokensPerBlock;
         }
-        for (auto layerId = 0; layerId < mCacheManager->getBlockManager().getNumLayers(); layerId++)
+        for (auto layerId = 0; layerId < cacheManager->getBlockManager().getNumLayers(); layerId++)
         {
-            auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId);
-            auto windowSize = static_cast<size_t>(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx));
+            auto poolIdx = cacheManager->getBlockManager().getLayerPoolIdx(layerId);
+            auto windowSize = static_cast<size_t>(cacheManager->getBlockManager().getPoolWindowSize(poolIdx));
             auto alignedWindowSize = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
             auto validTokenNum = (alignedWindowSize < maxNumTokens.value() ? alignedWindowSize : maxNumTokens.value());
             if (common::getEnvKVCacheTransferAllBlocksForWindow())

@@ -233,26 +239,20 @@ CacheTransBufferManager::CacheTransBufferManager(
         }
     }
 
-    mTransferBufferSize
-        = maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
-    mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
-    mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
-    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
-    mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
-        && FabricMemory::supportFbaricMemory();
-    if (mUseFabricMemory)
-    {
-        mTransferBufferSize = FabricMemory::getAlignedSize(mTransferBufferSize);
-    }
-    mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);
-    TLLM_LOG_INFO(
-        "CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
-        "mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d "
-        "mUseFabricMemory:%d mDataType:%d",
-        maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
-        mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, mDataType);
-
-    allocateBuffer();
+    return maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
+}
+
+CacheTransBufferManager::CacheTransBufferManager(
+    KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens, bool transferIndexerKCache)
+    : BaseTransBufferManager(computeTransferBufferSize(cacheManager, maxNumTokens, transferIndexerKCache),
+        transferIndexerKCache ? cacheManager->getIndexerKCachePool()->getDataType()
+                              : cacheManager->getPrimaryPool(0)->getDataType(),
+        maxNumTokens)
+    , mCacheManager{cacheManager}
+{
+    // TODO: FP4 dataSize
+    TLLM_CHECK(mCacheManager);
+    TLLM_LOG_INFO("CacheTransBufferManager created for KV cache");
 }
 
 size_t CacheTransBufferManager::preAllocBufferSize(

@@ -298,233 +298,4 @@ size_t CacheTransBufferManager::preAllocBufferSize(
     return preAllocBufferSize;
 }
 
-std::optional<int> CacheTransBufferManager::assignBufferIndexForSend()
-{
-    return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
-}
-
-void CacheTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
-{
-    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
-}
-
-std::optional<int> CacheTransBufferManager::assignBufferIndexForRecv()
-{
-    return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
-}
-
-void CacheTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
-{
-    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
-}
-
-std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateSendBuffers(
-    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-    runtime::BufferManager const& bufferManagerToUse)
-{
-    return getOrAllocateBuffers(
-        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceSendResource);
-}
-
-std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateRecvBuffers(
-    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-    runtime::BufferManager const& bufferManagerToUse)
-{
-    return getOrAllocateBuffers(
-        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceRecvResource);
-}
-
-runtime::ITensor::SharedPtr CacheTransBufferManager::getSendBuffer(std::optional<int> bufferId)
-{
-    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
-    if (bufferId.has_value())
-    {
-        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mSendBufferCount);
-        return mConcurrenceSendResource.mBuffers[bufferId.value()];
-    }
-    return nullptr;
-}
-
-runtime::ITensor::SharedPtr CacheTransBufferManager::getRecvBuffer(std::optional<int> bufferId)
-{
-    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
-    if (bufferId.has_value())
-    {
-        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mRecvBufferCount);
-        // TLLM_CHECK(mConcurrenceRecvResource.mBufferIndexFlag[bufferId.value()] == 1);
-        return mConcurrenceRecvResource.mBuffers[bufferId.value()];
-    }
-    return nullptr;
-}
-
-std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateBuffers(
-    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-    runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource)
-{
-    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
-    TLLM_CHECK(requestedNumberOfElements.size() >= static_cast<size_t>(targetNum));
-    std::vector<runtime::ITensor::SharedPtr> retSplitCaches;
-
-    size_t bufferCoverTargetNum = 0;
-
-    if (bufferId.has_value())
-    {
-        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < concurrenceResource.mBuffers.size());
-        TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1);
-        size_t preBufferEleSize = 0;
-        for (int i = 0; i < targetNum; i++)
-        {
-            // Strict checking.
-            if (preBufferEleSize + requestedNumberOfElements[i] <= mNumberOfElements)
-            {
-                auto slice = runtime::ITensor::slice(
-                    concurrenceResource.mBuffers[bufferId.value()], preBufferEleSize, requestedNumberOfElements[i]);
-                preBufferEleSize += requestedNumberOfElements[i];
-                bufferCoverTargetNum++;
-                retSplitCaches.push_back(std::move(slice));
-            }
-            else
-            {
-                retSplitCaches.push_back(bufferManagerToUse.gpu(
-                    runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
-            }
-        }
-        TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum);
-        if (bufferCoverTargetNum < static_cast<size_t>(targetNum))
-        {
-            TLLM_LOG_WARNING(
-                "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic "
-                "buffer which will fail with NIXL backend. It is recommended to set "
-                "cacheTransceiverConfig.MaxTokensInBuffer (cache_transceiver_config.max_tokens_in_buffer in config "
-                "YAML file) to a value greater than the maximum ISL of the processed requests. Otherwise, performance "
-                "may be degraded or transfer may fail. requestedNumberOfElements.size():%ld, "
-                "mNumberOfElements:%ld, requestedNumberOfElements[0]:%ld",
-                bufferCoverTargetNum, targetNum, requestedNumberOfElements.size(), mNumberOfElements,
-                requestedNumberOfElements[0]);
-        }
-    }
-    else
-    {
-        for (int i = 0; i < targetNum; i++)
-        {
-            retSplitCaches.push_back(bufferManagerToUse.gpu(
-                runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
-        }
-        bufferCoverTargetNum = targetNum;
-    }
-
-    return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer);
-}
-
-void CacheTransBufferManager::allocateBuffer()
-{
-    if (mOnlyUseDynamicBuffer)
-    {
-        return;
-    }
-    mNumberOfElements = mTransferBufferSize / common::getDTypeSize(mDataType);
-    mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
-    mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);
-    if (mUseFabricMemory)
-    {
-        mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
-        for (size_t i = 0; i < mSendBufferCount; i++)
-        {
-            mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
-            mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
-                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
-        }
-        for (size_t i = 0; i < mRecvBufferCount; i++)
-        {
-            mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
-            mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
-                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
-        }
-    }
-    else if (common::getEnvKVCacheTransferUseAsyncBuffer())
-    {
-        for (size_t i = 0; i < mSendBufferCount; i++)
-        {
-            mConcurrenceSendResource.mBuffers[i]
-                = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
-        }
-        for (size_t i = 0; i < mRecvBufferCount; i++)
-        {
-            mConcurrenceRecvResource.mBuffers[i]
-                = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
-        }
-        mBufferManager.getStream().synchronize();
-    }
-    else
-    {
-        for (size_t i = 0; i < mSendBufferCount; i++)
-        {
-            mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
-                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
-        }
-        for (size_t i = 0; i < mRecvBufferCount; i++)
-        {
-            mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
-                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
-        }
-    }
-}
-
-std::optional<int> CacheTransBufferManager::assignBufferIndex(
-    ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer)
-{
-    if (onlyUseDynamicBuffer)
-    {
-        return std::nullopt;
-    }
-    std::unique_lock lk(resource.mBuffersMutex);
-    resource.mBuffersCV.wait(
-        lk, [&resource, bufferCount]() { return static_cast<size_t>(resource.mConcurrence) < bufferCount; });
-    int bufferId = -1;
-    for (size_t i = 0; i < bufferCount; i++)
-    {
-        if (resource.mBufferIndexFlag[i] == 0)
-        {
-            bufferId = i;
-            resource.mBufferIndexFlag[bufferId] = 1;
-            resource.mConcurrence++;
-            break;
-        }
-    }
-    TLLM_CHECK_WITH_INFO(bufferId >= 0 && static_cast<size_t>(bufferId) < bufferCount,
-        " assignBufferIndex: Buffer index already assigned");
-
-    return bufferId;
-}
-
-void CacheTransBufferManager::freeBufferIndex(
-    ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer)
-{
-    if (onlyUseDynamicBuffer)
-    {
-        return;
-    }
-    if (bufferId.has_value())
-    {
-        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
-        {
-            std::scoped_lock lk(resource.mBuffersMutex);
-            resource.mBufferIndexFlag[bufferId.value()] = 0;
-        }
-        resource.mConcurrence--;
-        resource.mBuffersCV.notify_one();
-    }
-}
-
-size_t CacheTransBufferManager::getRecvBufferCount()
-{
-    return mRecvBufferCount;
-}
-
-size_t CacheTransBufferManager::getSendBufferCount()
-{
-    return mSendBufferCount;
-}
-
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
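computeTransferBufferSize above folds the pool dimensions into a per-token-per-layer byte count and clamps each pool's attention window to maxNumTokens. The accumulation statement itself is elided from this hunk, so the sketch below assumes the visible per-layer terms are summed; every number is illustrative:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t const dataSize = 2;                            // bytes per element, e.g. kHALF
    std::size_t const tokensPerBlock = 32;
    std::size_t const maxNumTokens = 8192;                     // must be a multiple of tokensPerBlock
    std::size_t const numLayers = 32;
    std::size_t const bytesPerTokenPerLayer = 1024 * dataSize; // pool dims folded together
    std::size_t const windowSize = 4096;                       // attention window of each pool

    std::size_t const alignedWindow = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
    std::size_t const validTokenNum = std::min(alignedWindow, maxNumTokens);

    // Assumed accumulation: one term per layer.
    std::size_t const bufferSize = numLayers * validTokenNum * bytesPerTokenPerLayer;
    std::printf("transfer buffer: %zu bytes (~%zu MiB)\n", bufferSize, bufferSize >> 20);
    return 0;
}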
@@ -17,13 +17,16 @@
 
 #pragma once
 
+#include "tensorrt_llm/batch_manager/baseTransBuffer.h"
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 
 #include <atomic>
 #include <condition_variable>
 #include <cstddef>
+#include <map>
 #include <optional>
 #include <unordered_map>
 #include <vector>

@@ -54,7 +57,9 @@ private:
     std::unique_ptr<Impl> pImpl;
 };
 
-class CacheTransBufferManager
+/// @brief KV Cache specific transfer buffer manager.
+/// Inherits common buffer management from BaseTransBufferManager.
+class CacheTransBufferManager : public BaseTransBufferManager
 {
 public:
     CacheTransBufferManager(KVCacheManager::BaseKVCacheManager* cacheManager,

@@ -64,62 +69,18 @@ public:
         SizeType32 tokensPerBlock,
         std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig = std::nullopt);
 
-    std::optional<int> assignBufferIndexForSend();
-    void freeBufferIndexForSend(std::optional<int> bufferId);
-    std::optional<int> assignBufferIndexForRecv();
-    void freeBufferIndexForRecv(std::optional<int> bufferId);
-
-    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateSendBuffers(
-        std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-        runtime::BufferManager const& bufferManagerToUse);
-
-    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateRecvBuffers(
-        std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-        runtime::BufferManager const& bufferManagerToUse);
-
-    runtime::ITensor::SharedPtr getSendBuffer(std::optional<int> bufferId);
-    runtime::ITensor::SharedPtr getRecvBuffer(std::optional<int> bufferId);
-    size_t getRecvBufferCount();
-    size_t getSendBufferCount();
-
-    std::optional<size_t> getMaxNumTokens()
+    /// @brief Get the KV cache manager.
+    [[nodiscard]] KVCacheManager::BaseKVCacheManager* getCacheManager() const noexcept
     {
-        return mMaxNumTokens;
+        return mCacheManager;
     }
 
 private:
-    struct ConcurrenceResource
-    {
-        std::unordered_map<int, runtime::ITensor::SharedPtr> mBuffers;
-        std::vector<int> mBufferIndexFlag;
-        std::mutex mBuffersMutex;
-        std::condition_variable mBuffersCV;
-        std::atomic<int> mConcurrence = 0;
-    };
-
-    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateBuffers(std::optional<int> bufferId,
-        int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-        runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource);
-
-    void allocateBuffer();
-    std::optional<int> assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer);
-    void freeBufferIndex(
-        ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer);
-
-    size_t mPreAllocBufferSize;
-    size_t mRecvBufferCount;
-    size_t mSendBufferCount;
-    size_t mTransferBufferSize;
-    bool mOnlyUseDynamicBuffer;
-    bool mUseFabricMemory;
-    size_t mNumberOfElements;
-    nvinfer1::DataType mDataType;
-    ConcurrenceResource mConcurrenceSendResource;
-    ConcurrenceResource mConcurrenceRecvResource;
+    /// @brief Compute transfer buffer size from KV cache configuration.
+    static size_t computeTransferBufferSize(KVCacheManager::BaseKVCacheManager* cacheManager,
+        std::optional<size_t> maxNumTokens, bool transferIndexerKCache);
+
     KVCacheManager::BaseKVCacheManager* mCacheManager;
-    runtime::BufferManager mBufferManager;
-    std::vector<std::unique_ptr<FabricMemory>> mFabricMemory;
-    std::optional<size_t> mMaxNumTokens;
 };
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
@@ -81,6 +81,11 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransceiver(
         backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
         TLLM_LOG_INFO("Enable NIXL KV cache transport.");
     }
+    else if (common::getEnvUseMooncakeKvCache())
+    {
+        backendType = executor::CacheTransceiverConfig::BackendType::MOONCAKE;
+        TLLM_LOG_INFO("Enable MOONCAKE KV cache transport.");
+    }
     else if (common::getEnvUseMPIKvCache())
     {
         backendType = executor::CacheTransceiverConfig::BackendType::MPI;

@@ -126,10 +131,10 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
         mGroupComm = std::make_shared<CacheTransceiverComm>(tensorrt_llm::pg_utils::get_world_pg());
     }
 
-    if (worldConfig.isTensorParallel())
+    if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
     {
         mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
-            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getRank()));
     }
     int kvFactor = 2;
     if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)

@@ -138,25 +143,24 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
     }
     mCacheState = std::make_unique<executor::kv_cache::CacheState>(cacheStateModelCfg, worldConfig,
         attentionLayerNumPerPP, dataType, attentionType, kvFactor, cacheManager->isEnableBlockReuse(),
-        cacheManager->isEnableIndexerKCache(), cacheManager->getIndexerKCacheIndexHeadDim(),
-        cacheManager->getIndexerKCacheQuantBlockSize());
+        cacheManager->isEnablePartialReuse(), cacheManager->isEnableIndexerKCache(),
+        cacheManager->getIndexerKCacheIndexHeadDim(), cacheManager->getIndexerKCacheQuantBlockSize());
 
     if (mCacheState->getParallelConfig().mEnableAttentionDP)
     {
-        int TPSizeInDPGroup
-            = mCacheState->getParallelConfig().mTensorParallelism / mCacheState->getParallelConfig().mDPsize;
-        int DPSize = mCacheState->getParallelConfig().mDPsize;
-        int TPRankInDPGroup = worldConfig.getTensorParallelRank() % TPSizeInDPGroup;
-
-        int DPRank = (worldConfig.getRank() - TPSizeInDPGroup * DPSize * worldConfig.getPipelineParallelRank()
-                         - TPRankInDPGroup)
-            / TPSizeInDPGroup;
-        // <PP,DP,TP>
-        mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
-        if (worldConfig.isTensorParallel())
+        int dpSize = mCacheState->getParallelConfig().mDPsize;
+        // dpRank is derived from the tensor parallel rank, which already accounts for CP.
+        // Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
+        // getTensorParallelRank() correctly extracts tpRank regardless of CP.
+        int dpRank = mCacheState->getParallelConfig().mDPrank;
+        // <PP,DP,TP,CP>
+        mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(dpRank, worldConfig.getRank()));
+        if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
         {
+            // Group ranks with same (ppRank, dpRank) accounting for CP.
             mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
-                mGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
+                mGroupComm->split(worldConfig.getPipelineParallelRank() * dpSize + dpRank, worldConfig.getRank()));
         }
     }
     bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
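To visualize the rank layout named in the comments above, rank = ppRank * (TP * CP) + tpRank * CP + cpRank, together with the new split color ppRank * dpSize + dpRank, here is a standalone sketch; deriving dpRank from tpRank is an assumption based on the removed arithmetic, and all sizes are invented:

#include <cstdio>

int main()
{
    int const tp = 4, cp = 2, pp = 2;
    int const dpSize = 2;             // attention-DP groups per pipeline stage
    int const tpPerDp = tp / dpSize;  // TP ranks inside one DP group
    for (int rank = 0; rank < pp * tp * cp; ++rank)
    {
        int const ppRank = rank / (tp * cp);
        int const tpRank = (rank % (tp * cp)) / cp;
        int const dpRank = tpRank / tpPerDp;        // assumption: DP partitions the TP axis
        int const color = ppRank * dpSize + dpRank; // split key from the hunk above
        std::printf("rank %2d -> pp %d tp %d dp %d -> group %d\n", rank, ppRank, tpRank, dpRank, color);
    }
    return 0;
}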
@ -203,9 +207,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
|
|||||||
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
|
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
|
||||||
{
|
{
|
||||||
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
|
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
|
||||||
mCacheTransBufferManagerPtrs, *mCacheState);
|
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
|
||||||
TLLM_LOG_INFO("NIXL Connection Manager created");
|
TLLM_LOG_INFO("NIXL Connection Manager created");
|
||||||
}
|
}
|
||||||
|
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
|
||||||
|
{
|
||||||
|
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
|
||||||
|
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
|
||||||
|
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
|
||||||
|
}
|
||||||
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
|
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
|
||||||
{
|
{
|
||||||
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
|
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
|
||||||
@@ -416,7 +426,8 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
     }
 }

-void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
+RequestStatuses CacheTransceiver::checkContextTransferStatus(
+    std::optional<int> const& atLeastRequestNum, bool markComplete)
 {
     bool blockAll = !atLeastRequestNum.has_value();
     std::optional<int> senderFutureTimeoutMs = std::nullopt;
@@ -475,6 +486,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
             toCompleteIdSet.insert(request->mRequestId);
         }

+    RequestStatuses requestsStatus{};
+
     // Complete all the requests in toCompleteIdSet
     for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
     {
@@ -488,7 +501,11 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
             if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
             {
                 future.get();
-                request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
+                requestsStatus.completedRequestIds.insert(request->mRequestId);
+                if (markComplete)
+                {
+                    request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
+                }
                 it = mSenderFutures.erase(it);
             }
             else if (status == std::future_status::timeout)
@@ -503,6 +520,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
                     "Future returned unexpected status for request %ld. Marking as error", request->mRequestId);

                 request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+                requestsStatus.errorRequestIds.insert(request->mRequestId);
                 it = mSenderFutures.erase(it);
             }
         }
@@ -511,6 +529,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
             TLLM_LOG_ERROR(
                 "Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
             request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+            requestsStatus.errorRequestIds.insert(request->mRequestId);
             it = mSenderFutures.erase(it);
         }
     }
@@ -519,6 +538,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
             ++it;
         }
     }
+
+    return requestsStatus;
 }

 void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
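The change makes checkContextTransferStatus report outcomes instead of only mutating request state: completed and errored request ids are collected and returned, and state mutation becomes opt-in via markComplete. A standalone sketch of the same drain-and-classify pattern, with RequestStatuses reduced to the two id sets the diff actually uses (the real type is defined elsewhere in the codebase):

    // Minimal sketch, assuming RequestStatuses holds completed/error id sets.
    #include <chrono>
    #include <cstdint>
    #include <exception>
    #include <future>
    #include <map>
    #include <set>

    struct RequestStatuses
    {
        std::set<std::uint64_t> completedRequestIds;
        std::set<std::uint64_t> errorRequestIds;
    };

    RequestStatuses drainFutures(std::map<std::uint64_t, std::future<void>>& futures)
    {
        RequestStatuses status;
        for (auto it = futures.begin(); it != futures.end();)
        {
            if (it->second.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
            {
                try
                {
                    it->second.get(); // rethrows any transfer error
                    status.completedRequestIds.insert(it->first);
                }
                catch (std::exception const&)
                {
                    status.errorRequestIds.insert(it->first);
                }
                it = futures.erase(it);
            }
            else
            {
                ++it; // still in flight; leave it for the next poll
            }
        }
        return status;
    }

Returning the classification lets the caller decide whether to flip request state, which is what the new markComplete flag expresses in the diff.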
@@ -358,8 +358,9 @@ public:

         TransceiverTag::Id id;
         RequestInfo info;
-        auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info)
-                                         : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
+        auto const* connection = isAgent
+            ? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate)
+            : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id));
         if (connection == nullptr && !mManager->isRunning())
         {
             TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
@@ -395,8 +396,8 @@ public:
         if (it == mRequestToSession.end())
         {
             auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
-                DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager,
-                info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
+                DataContext{tagFromRequestId(requestId), mTerminate}, mSelfState, info.getTransState(),
+                mBufferManager, info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
                 !common::getEnvKVCacheTimeOutputPath().empty());
             session.setTime(TransferSession::kTimeRequestInfo);
             it = mRequestToSession.emplace(requestId, std::move(session)).first;
@@ -685,6 +686,10 @@ private:
         {
             future.get();
         }
+        if (mResponseFuture.valid())
+        {
+            mResponseFuture.get();
+        }
     }

     void removeResponse(std::map<RequestIdType, Response>::iterator it)
@@ -820,8 +825,8 @@ public:
     {
         auto* cacheManager = mFormatter->getCacheManager();
         auto beam = 0;
-        auto requestedBlockRange
-            = getBlockRangeForReceiving(cacheManager, llmRequest, destCacheState.getEnableBlockReuse());
+        auto requestedBlockRange = getBlockRangeForReceiving(
+            cacheManager, llmRequest, destCacheState.getEnableBlockReuse(), destCacheState.getEnablePartialReuse());

         auto const& uniqueTokens = llmRequest.getUniqueTokens(beam);
         auto lastBlockKey
@@ -886,9 +891,9 @@ public:
            }
        }
        auto const& resource = getReceiveCacheResource(llmRequest);
-       return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
-           contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(),
-           &llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
+       return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId), mTerminate},
+           mSelfState, contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(),
+           requestInfo.getLastBlockKey(), &llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
    }

    std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
@@ -964,7 +969,7 @@ public:
             auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
             TLLM_CHECK(agentConnection);
             isReady = agentConnection->recvReadySignal(
-                executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
+                executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG, mTerminate});
         }
         else
         {
@@ -979,6 +984,7 @@ public:

     ~Impl()
     {
+        mTerminate.store(true);
         for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
         {
             asyncResource->mTerminate = true;
@@ -1134,6 +1140,7 @@ private:
     runtime::BufferManager mBufferManager;
     std::ofstream mMeasuresFile;
     std::mutex mMeasuresFileMutex;
+    std::atomic<bool> mTerminate{false};
 };

 void CacheSender::ImplDeleter::operator()(Impl* ptr)
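The recurring change here is a shutdown flag: a std::atomic<bool> mTerminate is added to the Impl and threaded through every DataContext so blocking receives can observe destruction and return null. A standalone sketch of that cooperative-cancellation shape (the transport polling is a placeholder, not the library's API):

    // Illustrative only: a blocking receive made interruptible by a shared flag.
    #include <atomic>
    #include <chrono>
    #include <optional>
    #include <thread>

    std::optional<int> recvInterruptible(std::atomic<bool> const& terminate)
    {
        while (!terminate.load(std::memory_order_acquire))
        {
            // Poll the transport with a short timeout instead of blocking forever,
            // so a request to terminate is observed promptly.
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
            // ... if (transport.ready()) return transport.read();
        }
        return std::nullopt; // terminated: caller treats this like a null connection
    }

    int main()
    {
        std::atomic<bool> terminate{false};
        std::thread worker([&] { recvInterruptible(terminate); });
        terminate.store(true, std::memory_order_release); // what ~Impl() now does first
        worker.join();
        return 0;
    }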
@@ -34,6 +34,7 @@
 #include "tensorrt_llm/runtime/worldConfig.h"

 #include <algorithm>
+#include <limits>
 #include <map>
 #include <optional>
 #include <utility>
@@ -99,6 +100,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
     auto const multimodalHashes = llmRequest.getMultimodalHashes();
     auto const multimodalPositions = llmRequest.getMultimodalPositions();
     auto const multimodalLengths = llmRequest.getMultimodalLengths();
+    auto const multimodalUuids = llmRequest.getMultimodalUuids();

     if (!multimodalHashes || !multimodalPositions || !multimodalLengths || !(*multimodalHashes)
         || (*multimodalHashes)->empty() || !(*multimodalPositions) || (*multimodalPositions)->empty()
@@ -114,7 +116,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
         return {};
     }

-    std::vector<MmKey> extraKeys; // MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
+    std::vector<MmKey> extraKeys;
     extraKeys.reserve((*multimodalPositions)->size());
     std::array<uint8_t, 32> mmHashArray;

@@ -144,7 +146,15 @@ std::vector<MmKey> generateBlockHashExtraKeys(
         if (endTokenIdx > startPos && startTokenIdx < startPos + length)
         {
             uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 0 : static_cast<uint64_t>(startTokenIdx - startPos);
-            extraKeys.emplace_back(mmHashArray, mmStartInBlock);
+
+            // Get UUID if available
+            std::optional<std::string> uuid = std::nullopt;
+            if (multimodalUuids && *multimodalUuids && i < (*multimodalUuids)->size())
+            {
+                uuid = (*(*multimodalUuids))[i];
+            }
+
+            extraKeys.emplace_back(mmHashArray, mmStartInBlock, std::move(uuid));
         }
     }

@@ -221,8 +231,10 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
     // block
     if (!blockKey.extraKeys.empty())
     {
-        for (auto const& [mmHash, startOffset] : blockKey.extraKeys)
+        for (auto const& mmKey : blockKey.extraKeys)
         {
+            auto const& mmHash = mmKey.hash;
+            auto const& startOffset = mmKey.startOffset;
             // Hash the multimodal hash array in 32-bit chunks (more efficient)
             for (size_t i = 0; i < 32; i += 4)
             {
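The deleted comment shows MmKey was previously a std::pair of hash and offset; the new emplace_back with three arguments and the .hash/.startOffset accesses imply it is now a struct carrying an optional UUID. A sketch of the shape the diff implies (the header's exact definition is not shown here) together with the 32-bit-chunk folding used by BlockKeyHasher:

    // Assumed shape inferred from the diff; illustrative, not the real header.
    #include <array>
    #include <cstdint>
    #include <cstring>
    #include <optional>
    #include <string>

    struct MmKeySketch
    {
        std::array<std::uint8_t, 32> hash; // content hash of the multimodal item
        std::uint64_t startOffset;         // first overlapping token within the block
        std::optional<std::string> uuid;   // stable identity when available
    };

    std::size_t foldIntoSeed(MmKeySketch const& key, std::size_t seed)
    {
        // Fold the 32-byte array in 32-bit chunks, as the hasher above does.
        for (std::size_t i = 0; i < key.hash.size(); i += 4)
        {
            std::uint32_t chunk;
            std::memcpy(&chunk, key.hash.data() + i, sizeof(chunk));
            seed ^= chunk + 0x9e3779b9 + (seed << 6) + (seed >> 2); // boost-style mix
        }
        seed ^= key.startOffset + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }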
@@ -416,6 +428,7 @@ void KVCacheBlock::setPrevBlockInSeq(BlockPtr prevBlock)

 void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block)
 {
+    std::lock_guard<std::mutex> lock(mNextBlocksMutex);
     if (mNextBlocks.find(blockKey) == mNextBlocks.end())
     {
         mNextBlocks[blockKey] = std::move(block);
@@ -425,6 +438,8 @@ void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block)
 std::tuple<bool, SizeType32, BlockPtr> KVCacheBlock::findMatchingBlock(
     BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const
 {
+    std::lock_guard<std::mutex> lock(mNextBlocksMutex);
+
     if (blockKey.uniqueTokens.size() == 0 || mNextBlocks.size() == 0)
     {
         return {false, 0, nullptr};
@@ -474,9 +489,36 @@ void KVCacheBlock::freeLeafBlock()

 void KVCacheBlock::removeNextBlock(BlockKey const& blockKey)
 {
+    std::lock_guard<std::mutex> lock(mNextBlocksMutex);
     mNextBlocks.erase(blockKey);
 }

+void KVCacheBlock::freeDescendantsRecursively()
+{
+    bool hasChildren = !mNextBlocks.empty();
+    if (hasChildren)
+    {
+        for (auto it = mNextBlocks.begin(); it != mNextBlocks.end();)
+        {
+            it->second->freeDescendantsRecursively();
+            TLLM_LOG_DEBUG("KVCacheBlock::freeDescendantsRecursively - Freeing block %d", it->second->getBlockId());
+            it = mNextBlocks.erase(it);
+        }
+    }
+    mPrevBlock = nullptr;
+}
+
+void KVCacheBlock::freeBlockAndAllDescendants()
+{
+    // free from previous block
+    if (mPrevBlock != nullptr)
+    {
+        mPrevBlock->removeNextBlock(mBlockKey);
+        mPrevBlock = nullptr;
+    }
+    freeDescendantsRecursively();
+}
+
 bool KVCacheBlock::isFull() const
 {
     return mIsFull;
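The two new methods detach a whole subtree of the radix tree: children are freed depth-first and each node drops its parent pointer, which breaks the owning references in both directions. A minimal standalone model of that detach logic (illustrative only, with a string key in place of BlockKey):

    // Sketch of bidirectional detach in a parent/child shared_ptr structure.
    #include <map>
    #include <memory>
    #include <string>

    struct Node
    {
        std::map<std::string, std::shared_ptr<Node>> next;
        std::shared_ptr<Node> prev;

        void freeDescendantsRecursively()
        {
            for (auto it = next.begin(); it != next.end();)
            {
                it->second->freeDescendantsRecursively();
                it = next.erase(it); // drops the last owning reference to the child
            }
            prev = nullptr; // break the upward link as well
        }
    };

Erasing from the map while iterating via the iterator returned by erase, as the real code does, is what keeps the traversal valid during destruction.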
@@ -956,19 +998,14 @@ void WindowBlockManager::freeLeafBlock(BlockPtr const& block)

 void WindowBlockManager::freeChildren(BlockPtr const& block)
 {
-    // Free all descendants of block
-    for (auto const& p : block->getNextBlocks())
-    {
-        auto childBlock = p.second;
-        freeChildren(childBlock);
-    }
-
-    // Free block
+    // Tell event manager we are freeing block
     if (mEventManager && blockInRadixTree(block))
     {
         mEventManager->enqueueRemovedEvent(block, mWindowSize);
     }
-    freeLeafBlock(block);
+    // Free block and all it's descendants from radix tree
+    block->freeBlockAndAllDescendants();
 }

 BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority,
@@ -1155,6 +1192,7 @@ std::optional<BlockKey> WindowBlockManager::findNewContextBlock(
     auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
     BlockKey ret;
     ret.loraTaskId = llmRequest.getLoraTaskId();
+    std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
     auto searchRoot = mCachedBlocksRoot;
     for (auto const& blockKey : blockKeys)
     {
@@ -1224,7 +1262,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
         auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr && blockItr != blockKeys.end()
             ? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse)
             : std::make_tuple(false, 0, nullptr);
-        if (matchingBlock != nullptr)
+        if (matchingBlock != nullptr && numMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen())
         {
             KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId();

@@ -1338,6 +1376,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
         }
     }

+    sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens);
     return numMatchedTokens;
 }

@@ -1370,15 +1409,16 @@ void WindowBlockManager::refreshBlocks()

 // There are two versions of BlockManager::addSequence function.
 // This is called when block reuse is enabled.
-void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
+SizeType32 BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
     LlmRequest& llmRequest, SizeType32 windowSize)
 {
-    mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
+    return mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
 }

 // There are two versions of WindowBlockManager::addSequence function.
 // This is called when block reuse is enabled.
-void WindowBlockManager::addSequence(
+// Returns the total prepopulatedPromptLen (including connector matched tokens) for this window.
+SizeType32 WindowBlockManager::addSequence(
     GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest)
 {
     auto const requestId = sequence.getRequestId();
@@ -1430,9 +1470,13 @@ void WindowBlockManager::addSequence(
         numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
     }

-    llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
-    TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
-        llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
+    // Return the total prepopulated length for this window (do not set on llmRequest here -
+    // the caller KVCacheManager::addSequence will use the minimum across all windows)
+    auto const totalPrepopulatedLen = prepopulatedPromptLen + numConnectorMatchedTokens;
+    TLLM_LOG_DEBUG(
+        "%s::addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
+        mLogPrefix.c_str(), llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
+    return totalPrepopulatedLen;
 }

 void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence)
@@ -1555,7 +1599,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
     }
 }

-std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
+std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
     std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
 {
     SizeType32 numBlocksStoredForReuse = 0;
@@ -1566,67 +1610,92 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
     auto searchRoot = mCachedBlocksRoot;
     bool needMatch = true;

-    auto numBlocks = blockKeys.size();
+    // There is no guarantee that these vectors will be the same length.
+    // Only iterate as long as we have valid blockKey and blockId.
+    auto numBlocks = std::min(blockKeys.size(), blockIds.size());
     std::vector<BlockPtr> storedBlocks;
-    std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
+    std::vector<KVCacheBlock::IdType> pinnedBlockIds;
     for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
     {
-        auto const bid = blockIds[blockCnt];
-        TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
-        auto& block = mAllBlocksById[bid];
-        auto const& blockKey = blockKeys[blockCnt];
-
-        auto [partialMatch, numMatched, matchedBlock]
-            = needMatch ? searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr);
-        if (matchedBlock != nullptr)
-        {
-            // Found match
-            TLLM_LOG_DEBUG(
-                "%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), matchedBlock->getBlockId());
-            searchRoot = matchedBlock;
-            // TODO possible optimization: if bid != matchedBlock->getBlockId(),
-            // block can be freed and inserted at mFreePrimaryBlocks.begin()
-        }
-        else
-        {
-            // No match
-            TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(),
-                block->getBlockId());
-            TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
-                "Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
-            needMatch = false; // no matching needed for following blocks
-            block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
-            block->setPrevBlock(searchRoot);
-            block->setPrevBlockInSeq(searchRoot);
-            searchRoot->addNextBlock(blockKey, block);
-
-            // Sanity check. The list of stored blocks should be connected.
-            TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());
-
-            storedBlocks.push_back(block);
-            TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
-                || block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
-            auto oldHash = block->getHash();
-            auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
-            if (oldHash != newHash)
-            {
-                TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
-                block->setHash(newHash);
-            }
-            searchRoot = block;
-            numBlocksStoredForReuse++;
-        }
-        if (pinBlocks)
-        {
-            searchRoot->incRefCount();
-        }
-        lastStoredId = searchRoot->getBlockId();
+        try
+        {
+            // Protect against blockIds being shorter than blockKeys.
+            auto const bid = blockIds.at(blockCnt);
+            TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
+            // We set blockId to an invalid value to indicate that a block has been released early for a limited
+            // attention layer. Make sure we don't store an invalid block because of this.
+            auto& block = mAllBlocksById.at(bid);
+            // Protect against blockKeys being shorter than blockIds.
+            auto const& blockKey = blockKeys.at(blockCnt);

+            // If either of the above error conditions occur, std::vector::at will throw an exception, which is caught
+            // further down. This will prevent an invalid block from being stored for reuse. The catch clause exits loop
+            // early, preventing blocks following an invalid block from being reused.

+            auto [partialMatch, numMatched, matchedBlock] = needMatch
+                ? searchRoot->findMatchingBlock(blockKey, false, false)
+                : std::make_tuple(false, 0, nullptr);
+            if (matchedBlock != nullptr)
+            {
+                // Found match
+                TLLM_LOG_DEBUG("%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(),
+                    matchedBlock->getBlockId());
+                searchRoot = matchedBlock;
+                // TODO possible optimization: if bid != matchedBlock->getBlockId(),
+                // block can be freed and inserted at mFreePrimaryBlocks.begin()
+            }
+            else
+            {
+                // No match
+                TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure",
+                    mLogPrefix.c_str(), block->getBlockId());
+                TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
+                    "Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
+                needMatch = false; // no matching needed for following blocks
+
+                if (block->getPrevBlock() != nullptr)
+                {
+                    block->getPrevBlock()->removeNextBlock(block->getBlockKey());
+                }
+                block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
+                block->setPrevBlock(searchRoot);
+                block->setPrevBlockInSeq(searchRoot);
+                searchRoot->addNextBlock(blockKey, block);
+
+                // Sanity check. The list of stored blocks should be connected.
+                TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());
+
+                storedBlocks.push_back(block);
+                TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
+                    || block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
+                auto oldHash = block->getHash();
+                auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
+                if (oldHash != newHash)
+                {
+                    TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
+                    block->setHash(newHash);
+                }
+                searchRoot = block;
+                numBlocksStoredForReuse++;
+            }
+            if (pinBlocks)
+            {
+                searchRoot->incRefCount();
+                pinnedBlockIds.push_back(searchRoot->getBlockId());
+            }
+        }
+        catch (std::out_of_range const& ex)
+        {
+            TLLM_LOG_WARNING("Out of range access, terminating storeBlocks early.");
+            // Prevent blocks following an invalid block from being reused.
+            break;
+        }
     }
     if (mEventManager)
     {
         mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
     }
-    return {numBlocksStoredForReuse, lastStoredId};
+    return {numBlocksStoredForReuse, pinnedBlockIds};
 }

 void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
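The bounds-safety idiom adopted in storeBlocks is worth isolating: iterate only up to the shorter vector, use .at() so any remaining mismatch throws std::out_of_range, and break rather than continue, because every entry after an invalid one would chain off an invalid parent. A self-contained illustration:

    // Standalone sketch of the .at()/catch/break pattern from the diff above.
    #include <algorithm>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    int main()
    {
        std::vector<int> keys{1, 2, 3};
        std::vector<int> ids{10, 20}; // deliberately shorter
        // One extra iteration past the common length triggers the guard.
        for (std::size_t i = 0; i < std::min(keys.size(), ids.size()) + 1; ++i)
        {
            try
            {
                std::cout << keys.at(i) << " -> " << ids.at(i) << '\n';
            }
            catch (std::out_of_range const&)
            {
                // Entries past the shorter vector are invalid; everything after
                // them would be chained to an invalid parent, so stop entirely.
                break;
            }
        }
        return 0;
    }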
@@ -1714,15 +1783,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
     return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
 }

-std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
     GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
-    std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
+    std::vector<KVCacheBlock::IdType> pinnedBlockIds;
     for (auto& [_, manager] : mWindowBlockManagers)
     {
-        lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
+        pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
     }
-    return lastStoredId;
+    return pinnedBlockIds;
 }

 std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
@@ -1731,9 +1800,22 @@ std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
     // Released block will be stored when reuse is enabled.
     // Reuse is implied to be enabled if llmRequest is provided.
     std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
+
+    // For now, the attention kernel only accepts a single
+    // "prepopulatedPromptLen", that is, all window sizes will use the same
+    // prepopulated prompt length, so it is meaningless right now to save
+    // blocks only for a certain window size while blocks in the other
+    // window size are not valid for saving for reuse.
+    bool isAllWindowSizesValidForStoreForReuse = true;
+    for (auto& [windowSize, manager] : mWindowBlockManagers)
+    {
+        isAllWindowSizesValidForStoreForReuse &= manager.isSequenceValidForStoreForReuse(sequence.getRequestId());
+    }
+
     for (auto& [_, manager] : mWindowBlockManagers)
     {
-        if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1)
+        if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1
+            || !isAllWindowSizesValidForStoreForReuse)
         {
             lastStoredId = manager.releaseBlocks(sequence, std::nullopt);
         }
@@ -1753,7 +1835,7 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
     }
 }

-void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
     // Use the first window size
     if (mWindowBlockManagers.empty())
@@ -1761,7 +1843,7 @@ void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
         return;
     }
     auto& firstManager = mWindowBlockManagers.begin()->second;
-    firstManager.unpinBlocksById(blockId);
+    firstManager.unpinBlocksById(blockIds);
 }

 void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
@@ -1774,21 +1856,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
     }
 }

-void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
-    if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
+    if (blockIds.empty())
     {
         return;
     }
-    auto block = mAllBlocksById[blockId];
-    while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
+
+    for (auto const& blockId : blockIds)
     {
-        block->decRefCount();
-        if (!block->hasRefs())
+        TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
+            "Block id %d is out of range", blockId);
+        auto block = mAllBlocksById[blockId];
+        if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
         {
-            mEvictionPolicy->releaseBlock(block);
+            block->decRefCount();
+            if (!block->hasRefs())
+            {
+                mEvictionPolicy->releaseBlock(block);
+            }
         }
-        block = std::move(block->getPrevBlock());
     }
 }

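The pinning contract changes shape here: storeBlocks now takes one reference per pinned block and returns exactly those ids, and unpinBlocksById releases exactly the ids it is given, instead of walking a single block's parent chain. A small model of that symmetric contract (illustrative types, not the manager's real classes):

    // Sketch: pin returns the exact id list that unpin must later consume.
    #include <cassert>
    #include <unordered_map>
    #include <vector>

    struct RefCountedBlock { int refs = 0; };

    std::vector<int> pinAll(std::unordered_map<int, RefCountedBlock>& blocks, std::vector<int> const& ids)
    {
        std::vector<int> pinned;
        for (int id : ids)
        {
            blocks[id].refs++; // one reference per pinned block
            pinned.push_back(id);
        }
        return pinned; // caller hands this exact list back to unpin
    }

    void unpinAll(std::unordered_map<int, RefCountedBlock>& blocks, std::vector<int> const& pinned)
    {
        for (int id : pinned)
        {
            assert(blocks[id].refs > 0);
            blocks[id].refs--; // block becomes evictable again at zero
        }
    }

Returning the full id vector avoids the earlier ambiguity where only the last stored id was reported and the parent chain had to be re-walked to release references.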
@@ -1856,7 +1943,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
     (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
 }

-std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
     GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
     auto constexpr beamIdx = 0;
@@ -1869,7 +1956,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
     auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
     auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
     auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
-    return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
+
+    auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
+
+    return pinnedBlockIds;
 }

 std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
@@ -1908,7 +1998,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
     std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
         [](BlockPtr const& block) { return block->getBlockId(); });

-    auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
+    auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
     TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
         sequence.getRequestId(), numBlocksStoredForReuse);
 }
@@ -2376,6 +2466,9 @@ void KVCacheManager::addSequence(
             "[kv cache manager] Encounter existing sequence %d, skip sequence storage validity initialization",
             requestId);
     }
+    // Track the minimum prepopulated length across all windows (for VSWA with mixed isSWA flags)
+    SizeType32 minPrepopulatedPromptLen = std::numeric_limits<SizeType32>::max();
+
     for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
     {
         // NOTE: Caller to KVCacheManager::addSequence should deal with the chunking
@@ -2387,7 +2480,11 @@ void KVCacheManager::addSequence(
         auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock());
         if (mEnableBlockReuse)
         {
-            mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize);
+            auto const prepopulatedLen
+                = mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize);
+            // Use the minimum prepopulated length across all windows to ensure correctness
+            // when there's a mix of SWA and non-SWA windows (e.g., VSWA case)
+            minPrepopulatedPromptLen = std::min(minPrepopulatedPromptLen, prepopulatedLen);
         }
         else
         {
@@ -2406,6 +2503,13 @@ void KVCacheManager::addSequence(
         mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize);
     }

+    // Set the prepopulated prompt length once using the minimum across all windows
+    if (llmRequest && mEnableBlockReuse)
+    {
+        TLLM_LOG_DEBUG("KVCacheManager::addSequence: Setting prepopulatedPromptLen to %d", minPrepopulatedPromptLen);
+        llmRequest->setPrepopulatedPromptLen(minPrepopulatedPromptLen, getTokensPerBlock());
+    }
+
     if (llmRequest)
     {
         // Update statistics for block allocations/reuse per request.
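The rule introduced above is simple but easy to get wrong: each window reports how many prompt tokens it already holds, and the request may only be credited with the length that every window can serve, since the attention kernel accepts a single prepopulatedPromptLen for all window sizes. A tiny self-contained illustration with assumed values:

    // Sketch of the minimum-across-windows rule; numbers are illustrative.
    #include <algorithm>
    #include <cassert>
    #include <limits>
    #include <vector>

    int main()
    {
        // Per-window prepopulated lengths, e.g. a full-attention window that
        // matched 96 prompt tokens and a sliding window that only kept 32.
        std::vector<int> perWindow{96, 32};

        int minPrepopulated = std::numeric_limits<int>::max();
        for (int len : perWindow)
        {
            minPrepopulated = std::min(minPrepopulated, len);
        }
        // Only 32 tokens can be skipped safely: a single prepopulatedPromptLen
        // is shared by all window sizes, so the weakest window wins.
        assert(minPrepopulated == 32);
        return 0;
    }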
@@ -2485,15 +2589,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
     return lastStoredId;
 }

-std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
     RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
     TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
     auto& sequence = getSequence(requestId);
-    std::optional<KVCacheBlock::IdType> lastStoredId
-        = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
+    auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
     TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
-    return lastStoredId;
+    return pinnedBlockIds;
 }

 void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
@@ -2508,9 +2611,29 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
     mBlockManager.pinBlocks(sequence);
 }

-void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
-    mBlockManager.unpinBlocksById(blockId);
+    mBlockManager.unpinBlocksById(blockIds);
+}
+
+tle::RetentionPriority KVCacheManager::getPriorityByBlockId(KVCacheBlock::IdType blockId, SizeType32 windowSize) const
+{
+    try
+    {
+        BlockPtr const& block = mBlockManager.getBlockById(blockId, windowSize);
+        if (block)
+        {
+            return block->getPriority();
+        }
+        TLLM_LOG_WARNING("getPriorityByBlockId: Block ID %d not found in window %d", blockId, windowSize);
+        return tle::KvCacheRetentionConfig::kDefaultRetentionPriority;
+    }
+    catch (std::out_of_range const& ex)
+    {
+        TLLM_LOG_WARNING(
+            "getPriorityByBlockId: Block ID %d or window size %d out of range: %s", blockId, windowSize, ex.what());
+        return tle::KvCacheRetentionConfig::kDefaultRetentionPriority;
+    }
 }

 SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
@@ -2847,6 +2970,18 @@ void KVCacheManager::removeToken(RequestIdType requestId)

 void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths)
 {
+    // Check if the sequence still exists before rewinding
+    // In overlap mode with MTP, the request may have been terminated and removed
+    // from mSequences before rewindKVCache is called
+    {
+        std::scoped_lock lck(mSequencesMtx);
+        if (mSequences.find(requestId) == mSequences.end())
+        {
+            TLLM_LOG_DEBUG("Request %lu has already been removed from KV cache manager, skipping rewind", requestId);
+            return;
+        }
+    }
+
     for (SizeType32 si = 0; si < rewindLengths; ++si)
     {
         removeToken(requestId);
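The guard keeps the lock scope to the existence check only, then rewinds outside the lock. A minimal model of that check-then-early-return shape (illustrative class, not the manager itself):

    // Sketch of the scoped existence guard added to rewindKVCache.
    #include <map>
    #include <mutex>

    struct Manager
    {
        std::mutex mtx;
        std::map<unsigned long, int> sequences;

        void rewind(unsigned long requestId)
        {
            {
                std::scoped_lock lock(mtx); // held only for the lookup
                if (sequences.find(requestId) == sequences.end())
                {
                    return; // request terminated earlier; nothing to rewind
                }
            }
            // ... per-token rewind proceeds without holding the lock ...
        }
    };

Note the remaining window: the sequence could still be removed between the check and the rewind, so this guard addresses the common overlap-mode race rather than providing a full atomicity guarantee.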

cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp (new file, 239 lines)
@@ -0,0 +1,239 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"
+#include "tensorrt_llm/common/logger.h"
+#include "tensorrt_llm/common/memoryUtils.h"
+#include <cassert>
+#include <cstdio>
+#include <cuda.h>
+#include <fcntl.h>
+#include <memory>
+#include <unistd.h>
+#include <vector>
+
+namespace tc = tensorrt_llm::common;
+using namespace tensorrt_llm::runtime;
+
+namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
+{
+
+template <typename Func>
+bool loopedReadWrite(Func&& func, ssize_t size) noexcept
+{
+    ssize_t count = 0;
+    while (count < size)
+    {
+        ssize_t bytes = func(count);
+        if (bytes <= 0)
+        {
+            if (errno == EINTR)
+            {
+                continue; // Retry on interrupt
+            }
+            TLLM_LOG_ERROR("Disk read/write failed: %s\n", strerror(errno));
+            return false;
+        }
+        count += bytes;
+    }
+    assert(count == size);
+    return true;
+}
+
+bool writeAll(int fd, ssize_t pos, void const* data, ssize_t size) noexcept
+{
+    return loopedReadWrite([=](ssize_t finished)
+        { return pwrite(fd, static_cast<std::byte const*>(data) + finished, size - finished, pos + finished); },
+        size);
+}
+
+bool readAll(int fd, ssize_t pos, void* data, ssize_t size) noexcept
+{
+    return loopedReadWrite([=](ssize_t finished)
+        { return pread(fd, static_cast<std::byte*>(data) + finished, size - finished, pos + finished); },
+        size);
+}
+
+template <typename DstAddr, typename SrcAddr>
+struct UserData
+{
+    std::vector<Task<DstAddr, SrcAddr>> tasks;
+    ssize_t numBytes;
+};
+
+CUDA_CB void hostFnDiskToDiskCopy(void* userData) noexcept
+{
+    // @TODO: enable multi-threading with a thread pool
+    using Data = UserData<DiskAddress, DiskAddress>;
+    auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
+    std::vector<std::byte> buffer(data->numBytes);
+    bool success = true;
+    for (auto const& t : data->tasks)
+    {
+        success = success && readAll(t.src.fd, t.src.pos, buffer.data(), data->numBytes);
+        success = success && writeAll(t.dst.fd, t.dst.pos, buffer.data(), data->numBytes);
+    }
+    if (!success)
+    {
+        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToDiskCopy failed.\n");
+    }
+}
+
+CUDA_CB void hostFnDiskToHostCopy(void* userData) noexcept
+{
+    // @TODO: enable multi-threading with a thread pool
+    using Data = UserData<MemAddress, DiskAddress>;
+    auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
+    bool success = true;
+    for (auto const& t : data->tasks)
+    {
+        success = success && readAll(t.src.fd, t.src.pos, reinterpret_cast<void*>(t.dst), data->numBytes);
+    }
+    if (!success)
+    {
+        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToHostCopy failed.\n");
+    }
+}
+
+CUDA_CB void hostFnHostToDiskCopy(void* userData) noexcept
+{
+    // @TODO: enable multi-threading with a thread pool
+    using Data = UserData<DiskAddress, MemAddress>;
+    auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
+    bool success = true;
+    for (auto const& t : data->tasks)
+    {
+        success = success && writeAll(t.dst.fd, t.dst.pos, reinterpret_cast<void const*>(t.src), data->numBytes);
+    }
+    if (!success)
+    {
+        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnHostToDiskCopy failed.\n");
+    }
+}
+
+CUDA_CB void hostFnHostToHostCopy(void* userData) noexcept
+{
+    // @TODO: enable multi-threading with a thread pool
+    using Data = UserData<MemAddress, MemAddress>;
+    auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
+    for (auto const& t : data->tasks)
+    {
+        memcpy(reinterpret_cast<void*>(t.dst), reinterpret_cast<void const*>(t.src), data->numBytes);
+    }
+}
+
+CUresult copyDiskToDisk(std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
+{
+    using Data = UserData<DiskAddress, DiskAddress>;
+    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
+    return cuLaunchHostFunc(stream, hostFnDiskToDiskCopy, data.release());
+}
+
+CUresult copyDiskToHost(std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
+{
+    using Data = UserData<MemAddress, DiskAddress>;
+    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
+    return cuLaunchHostFunc(stream, hostFnDiskToHostCopy, data.release());
+}
+
+CUresult copyHostToDisk(std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
+{
+    using Data = UserData<DiskAddress, MemAddress>;
+    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
+    return cuLaunchHostFunc(stream, hostFnHostToDiskCopy, data.release());
+}
+
+CUresult copyHostToHost(std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
+{
+    using Data = UserData<MemAddress, MemAddress>;
+    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
+    return cuLaunchHostFunc(stream, hostFnHostToHostCopy, data.release());
+}
+
+SizeType32 IndexMapper::addNewSequence(LlmRequest::RequestIdType requestId)
+{
+    TLLM_CHECK(indexMap_.find(requestId) == indexMap_.end());
+    auto iter = freeIndices_.begin();
+    TLLM_CHECK_WITH_INFO(iter != freeIndices_.end(), "No free index found");
+    auto index = *iter;
+    freeIndices_.erase(iter);
+    indexMap_[requestId] = index;
+    return index;
+}
+
+SizeType32 IndexMapper::getIndex(LlmRequest::RequestIdType requestId)
+{
+    auto iter = indexMap_.find(requestId);
+    TLLM_CHECK_WITH_INFO(iter != indexMap_.end(), "Request ID not found in IndexMapper");
+    return iter->second;
+}
+
+void IndexMapper::removeSequence(LlmRequest::RequestIdType requestId)
+{
+    auto iter = indexMap_.find(requestId);
+    TLLM_CHECK(iter != indexMap_.end());
+    auto index = iter->second;
+    freeIndices_.insert(index);
+    indexMap_.erase(iter);
+}
+
+at::Tensor IndexMapper::getCopyIndex(
+    std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 numContext, SizeType32 beamWidth)
+{
+    int numSeqs = numContext + beamWidth * (requestIds.size() - numContext);
+    SizeType32 batchSize = static_cast<SizeType32>(requestIds.size());
+    SizeType32 idx = 0;
+    for (SizeType32 i = 0; i < batchSize; i++)
+    {
+        if (i < numContext)
+        {
+            copyIndex_[idx++] = this->getIndex(requestIds[i]) * maxBeamWidth_;
+        }
+        else
+        {
+            for (SizeType32 j = 0; j < beamWidth; j++)
+            {
+                copyIndex_[idx++] = this->getIndex(requestIds[i]) * maxBeamWidth_ + j;
+            }
+        }
+    }
+
+    TLLM_CHECK_WITH_INFO(idx == numSeqs, "Index mapper failed to generate copy index");
+
+    return copyIndex_.slice(0, 0, numSeqs);
+}
+
+IndexMapper::IndexMapper(SizeType32 maxBatchSize, SizeType32 maxBeamWidth)
+    : maxBeamWidth_(maxBeamWidth)
+{
+    indexMap_.reserve(maxBatchSize);
+    for (SizeType32 i = 0; i < maxBatchSize; i++)
+    {
+        freeIndices_.insert(i);
+    }
+    // Allocate copyIndex_ memory as pinned (page-locked) host memory
+    copyIndex_
+        = at::empty({maxBatchSize * maxBeamWidth}, at::TensorOptions().dtype(at::ScalarType::Int).pinned_memory(true));
+}
+
+IndexMapper::~IndexMapper()
+{
+    indexMap_.clear();
+    freeIndices_.clear();
+}
+
+} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
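The writeAll/readAll helpers above exist because a single pwrite or pread may transfer fewer bytes than requested or fail with EINTR, so they loop until the full size is moved. A usage sketch against a hypothetical scratch file (the path and sizes are made up for illustration):

    // Illustrative positional I/O; the helpers in the file above add the
    // short-transfer/EINTR retry loop that this bare version omits.
    #include <fcntl.h>
    #include <unistd.h>
    #include <vector>

    int main()
    {
        int fd = open("/tmp/kv_block.bin", O_CREAT | O_RDWR, 0600); // hypothetical path
        if (fd < 0) { return 1; }

        std::vector<unsigned char> block(4096, 0xAB);
        // pwrite/pread take an explicit offset, so concurrent tasks can target
        // disjoint regions of one file without sharing a file cursor.
        bool ok = true;
        ssize_t written = pwrite(fd, block.data(), block.size(), /*pos=*/8192);
        ok = ok && (written == static_cast<ssize_t>(block.size()));

        std::vector<unsigned char> readback(4096);
        ssize_t got = pread(fd, readback.data(), readback.size(), 8192);
        ok = ok && (got == static_cast<ssize_t>(readback.size()));

        close(fd);
        return ok ? 0 : 1;
    }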

cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu (new file, 292 lines)
@@ -0,0 +1,292 @@
|
/*
|
||||||
|
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "kvCacheManagerV2Utils.h"
|
||||||
|
#include "tensorrt_llm/common/assert.h"
|
||||||
|
#include "tensorrt_llm/common/cudaUtils.h"
|
||||||
|
#include "tensorrt_llm/common/envUtils.h"
|
||||||
|
#include "tensorrt_llm/common/memoryUtils.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <array>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
|
||||||
|
{
|
||||||
|
using Grain = uint4;
|
||||||
|
constexpr uint32_t ctaSize = 128;
|
||||||
|
constexpr uint32_t copyBlockCtaSize = 128;
|
||||||
|
constexpr uint32_t copyBlocknbBufs = 2;
|
||||||
|
constexpr uint32_t nbBufs = 4;
|
||||||
|
constexpr uint32_t grainBytes = sizeof(Grain);
|
||||||
|
|
||||||
|
using MMTask = Task<MemAddress, MemAddress>;
|
||||||
|
|
||||||
|
__device__ __host__ inline uint32_t divUp(uint32_t a, uint32_t b)
|
||||||
|
{
|
||||||
|
return (a + b - 1) / b;
|
||||||
|
}
template <uint32_t N>
__global__ void batchedCopy(std::array<MMTask, N> const __grid_constant__ tasks, uint32_t nbBytes)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("griddepcontrol.launch_dependents;\n");
#endif
    assert(nbBytes % sizeof(Grain) == 0);
    __shared__ Grain data[nbBufs][ctaSize];

    uint32_t const nbTasks = gridDim.y;
    assert(nbTasks <= N);
    auto const& task = tasks[blockIdx.y];
    uint32_t const nbSplits = gridDim.x;
    uint32_t const idxSplit = blockIdx.x;
    uint32_t const tid = threadIdx.x;

    constexpr uint32_t bytesPerIter = grainBytes * ctaSize;

    uint32_t const totalIters = divUp(nbBytes, bytesPerIter);
    uint32_t const maxItersPerCta = divUp(totalIters, nbSplits);
    uint32_t const idxGrainBeg = ctaSize * maxItersPerCta * idxSplit + tid;
    uint32_t const idxGrainEnd = std::min(idxGrainBeg + ctaSize * maxItersPerCta, nbBytes / grainBytes);

#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("griddepcontrol.wait;\n");
#endif
    for (uint32_t i = 0; i < maxItersPerCta + nbBufs; i++)
    {
        uint32_t const idxBuf = i % nbBufs;
        if (i >= nbBufs)
        {
            uint32_t const stIter = i - nbBufs;
            assert(idxBuf == (stIter % nbBufs));
            Grain const& src = data[idxBuf][tid];
            uint32_t const idxGrain = idxGrainBeg + ctaSize * stIter;
            Grain& dst = reinterpret_cast<Grain*>(task.dst)[idxGrain];
            asm volatile("cp.async.wait_group %0;\n" ::"n"(nbBufs - 1) : "memory");
            if (idxGrain < idxGrainEnd)
            {
                dst = src;
            }
        }
        uint32_t const ldIter = i;
        Grain* const dst = &data[idxBuf][tid];
        uint32_t const idxGrain = idxGrainBeg + ctaSize * ldIter;
        Grain const* const src = &reinterpret_cast<Grain const*>(task.src)[idxGrain];
        if (idxGrain < idxGrainEnd)
        {
            uint32_t const size = grainBytes;
            asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)),
                "l"(src), "n"(grainBytes), "r"(size)
                : "memory");
        }
        asm volatile("cp.async.commit_group;\n" : : : "memory");
    }
}
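
// The loop above forms a software pipeline nbBufs stages deep: iteration i issues a
// cp.async load into shared buffer i % nbBufs while the store phase drains the buffer
// filled nbBufs iterations earlier. cp.async.wait_group(nbBufs - 1) blocks until at
// most nbBufs - 1 copy groups remain in flight, guaranteeing the oldest buffer has
// landed in shared memory; the nbBufs trailing loop iterations flush the pipeline.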
template <uint32_t N>
CUresult launchBatchedCopyImpl(
    bool lowBandwidth, MMTask const* tasks, uint32_t nbTasks, uint32_t nbBytes, cudaStream_t stream)
{
    TLLM_CHECK(nbTasks <= N);
    TLLM_CHECK_WITH_INFO(
        nbBytes % sizeof(Grain) == 0, "Not implemented case: nbBytes = %d must be a multiple of 16.", nbBytes);
    std::array<MMTask, N> const* pTasks;
    std::array<MMTask, N> tmp;
    if (nbTasks < N)
    {
        std::copy_n(tasks, nbTasks, tmp.begin());
        pTasks = &tmp;
    }
    else
    {
        pTasks = reinterpret_cast<std::array<MMTask, N> const*>(tasks);
    }
    uint32_t const nbSplits = lowBandwidth ? 1 : divUp(nbBytes, grainBytes * ctaSize * 2);
    void* args[] = {(void*) pTasks, (void*) &nbBytes};
    static CUkernel const kernel = []() -> CUkernel
    {
        cudaKernel_t kernel = nullptr;
        TLLM_CUDA_CHECK(cudaGetKernel(&kernel, reinterpret_cast<void const*>(&batchedCopy<N>)));
        return kernel;
    }();
    return common::CUDADriverWrapper::getInstance()->cuLaunchKernel(reinterpret_cast<CUfunction>(kernel), nbSplits,
        nbTasks, 1, // gridDimX, gridDimY, gridDimZ
        ctaSize, 1, 1, // blockDimX, blockDimY, blockDimZ
        0, // sharedMemBytes
        stream, args, nullptr);
}

// When bandwidth is low, e.g. when host memory is involved, we avoid splitting as fewer CTAs should be enough to
// saturate the bandwidth.
CUresult launchBatchedCopy(bool lowBandwidth, std::vector<MMTask> const& tasks, uint32_t nbBytes, cudaStream_t stream)
{
    constexpr uint32_t maxN = 256;
    uint32_t const nbWholeBatches = tasks.size() / maxN;
    for (uint32_t i = 0; i < nbWholeBatches; i++)
    {
        CUresult const err = launchBatchedCopyImpl<maxN>(lowBandwidth, tasks.data() + maxN * i, maxN, nbBytes, stream);
        if (err != CUDA_SUCCESS)
        {
            return err;
        }
    }
    {
        auto const* const pTasks = tasks.data() + maxN * nbWholeBatches;
        auto const batchSize = tasks.size() % maxN;
        if (batchSize == 0)
        {
            return CUDA_SUCCESS;
        }
        if (batchSize > maxN / 2)
        {
            return launchBatchedCopyImpl<maxN>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
        }
        if (batchSize > maxN / 4)
        {
            return launchBatchedCopyImpl<maxN / 2>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
        }
        if (batchSize > maxN / 8)
        {
            return launchBatchedCopyImpl<maxN / 4>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
        }
        return launchBatchedCopyImpl<maxN / 8>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
    }
}

CUresult copyHostToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    return launchBatchedCopy(true, tasks, numBytes, stream);
}

CUresult copyDeviceToHost(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    return launchBatchedCopy(true, tasks, numBytes, stream);
}

CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    return launchBatchedCopy(false, tasks, numBytes, stream);
}
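
// A minimal usage sketch (illustrative only; the function name and the pointer and
// size arguments below are placeholders, not part of the original file):
[[maybe_unused]] static CUresult exampleBatchedCopyUsage(
    void* dstA, void const* srcA, void* dstB, void const* srcB, cudaStream_t stream)
{
    // Two same-sized copies batched into a single kernel launch; every task in a
    // batch shares the same byte count, which must be a multiple of 16.
    std::vector<MMTask> tasks{{reinterpret_cast<MemAddress>(dstA), reinterpret_cast<MemAddress>(srcA)},
        {reinterpret_cast<MemAddress>(dstB), reinterpret_cast<MemAddress>(srcB)}};
    return copyDeviceToDevice(tasks, /*numBytes=*/4096, stream);
}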

// dst_tensor[:, :num_seqs, 0] = src_tensor[:, copy_idx]
// dst_tensor[:, :num_seqs, 1] = dst_tensor[:, :num_seqs, 0] + 1
__global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict__ srcPtr,
    SizeType32* __restrict__ dstPtr, SizeType32 const srcMaxNumSequences, SizeType32 const dstMaxNumSequences,
    SizeType32 numBlocksPerSeq, SizeType32 const* __restrict__ copyIndex, SizeType32 const* __restrict__ indexScales,
    SizeType32 const* __restrict__ kvOffset)
{
    constexpr uint32_t kvFactor = 2;
    constexpr auto elemPerAccess = sizeof(PackedInt) / sizeof(SizeType32);

    __shared__ PackedInt data[copyBlocknbBufs][copyBlockCtaSize];

    auto const iterPerSeq = divUp(numBlocksPerSeq * sizeof(SizeType32), sizeof(PackedInt) * copyBlockCtaSize);
    auto const tid = threadIdx.x;
    auto const poolIdx = blockIdx.x;
    auto const seqIdx = blockIdx.y;
    auto const seqDimStride = kvFactor * numBlocksPerSeq;
    uint32_t const srcIdxBeg = tid * elemPerAccess + (poolIdx * srcMaxNumSequences + copyIndex[seqIdx]) * seqDimStride;
    uint32_t const dstIdxKBeg = tid * elemPerAccess + (poolIdx * dstMaxNumSequences + seqIdx) * seqDimStride;
    uint32_t const dstIdxVBeg = dstIdxKBeg + numBlocksPerSeq;

    uint32_t const srcIdxEnd = (poolIdx * srcMaxNumSequences + copyIndex[seqIdx]) * seqDimStride + numBlocksPerSeq;

    for (uint32_t i = 0; i < iterPerSeq + copyBlocknbBufs; i++)
    {
        uint32_t const idxBuf = i % copyBlocknbBufs;
        if (i >= copyBlocknbBufs)
        {
            uint32_t const stIter = i - copyBlocknbBufs;
            assert(idxBuf == (stIter % copyBlocknbBufs));
            auto const offset = copyBlockCtaSize * stIter * elemPerAccess;
            SizeType32 const srcIdx = srcIdxBeg + offset;
            SizeType32 const dstIdxK = dstIdxKBeg + offset;
            SizeType32 const dstIdxV = dstIdxVBeg + offset;
            PackedInt const& src = data[idxBuf][tid];
            PackedInt& dstK = *reinterpret_cast<PackedInt*>(dstPtr + dstIdxK);
            PackedInt& dstV = *reinterpret_cast<PackedInt*>(dstPtr + dstIdxV);
            asm volatile("cp.async.wait_group %0;\n" ::"n"(copyBlocknbBufs - 1) : "memory");
            if (srcIdx < srcIdxEnd)
            {
#pragma unroll
                for (uint32_t j = 0; j < elemPerAccess; j++)
                {
                    auto const val = src.unpacked[j];
                    dstK.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (indexScales[poolIdx] * val);
                    dstV.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (indexScales[poolIdx] * val + kvOffset[poolIdx]);
                }
            }
        }
        uint32_t const ldIter = i;
        PackedInt* const dst = &data[idxBuf][tid];
        uint32_t const srcIdx = srcIdxBeg + copyBlockCtaSize * ldIter * elemPerAccess;
        PackedInt const* const src = reinterpret_cast<PackedInt const*>(srcPtr + srcIdx);
        if (srcIdx < srcIdxEnd)
        {
            uint32_t const size = sizeof(PackedInt);
            asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)),
                "l"(src), "n"(size), "r"(size)
                : "memory");
        }
        asm volatile("cp.async.commit_group;\n" : : : "memory");
    }
}
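
// Worked example of the transform above: with indexScales[poolIdx] == 2 and
// kvOffset[poolIdx] == 1, a source block offset v is written as 2 * v on the K side
// and 2 * v + 1 on the V side, consistent with the dst_tensor comment before the
// kernel; BAD_PAGE_INDEX entries pass through unscaled.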

// Host-side launcher
void copyBatchBlockOffsetsToDevice(ITensor const& input, ITensor& output, ITensor const& copyIndex,
    ITensor const& indexScales, ITensor const& kvOffset, CUstream stream) noexcept
{
    using namespace tensorrt_llm::runtime;

    auto const* srcPtr = bufferCast<tk::KVCacheIndex::UnderlyingType const>(input);
    auto* dstPtr = bufferCast<tk::KVCacheIndex::UnderlyingType>(
        output); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq]
    auto const* copyIndexPtr = bufferCast<SizeType32 const>(copyIndex);
    auto const* indexScalesPtr = bufferCast<SizeType32 const>(indexScales);
    auto const* kvOffsetPtr = bufferCast<SizeType32 const>(kvOffset);
    auto const& srcShape = input.getShape();
    auto const& dstShape = output.getShape();
    auto const& copyIndexShape = copyIndex.getShape();

    TLLM_CHECK(srcShape.nbDims == 4); // [numPools, srcMaxNumSequences, kvFactor, numBlocksPerSeq]
    TLLM_CHECK(dstShape.nbDims == 4); // [numPools, dstMaxNumSequences, kvFactor, numBlocksPerSeq]

    SizeType32 numPools = srcShape.d[0];
    SizeType32 srcMaxNumSequences = srcShape.d[1];
    SizeType32 dstMaxNumSequences = dstShape.d[1];
    SizeType32 numBlocksPerSeq = srcShape.d[3];
    SizeType32 numSeqs = copyIndexShape.d[0];

    if (numSeqs == 0)
    {
        return;
    }

    TLLM_CHECK_WITH_INFO((numBlocksPerSeq * sizeof(SizeType32)) % sizeof(PackedInt) == 0,
        "Not implemented case: numBlocksPerSeq * sizeof(SizeType32) = %zu must be a multiple of %zu.",
        static_cast<size_t>(numBlocksPerSeq * sizeof(SizeType32)), static_cast<size_t>(sizeof(PackedInt)));

    dim3 gridDim(numPools, numSeqs, 1);
    dim3 blockDim(copyBlockCtaSize);

    copyBatchBlockOffsetsToDeviceKernel<<<gridDim, blockDim, 0, stream>>>(srcPtr, dstPtr, srcMaxNumSequences,
        dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr, indexScalesPtr, kvOffsetPtr);
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h (new file, 101 lines)
@@ -0,0 +1,101 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <ATen/ATen.h>
#include <cstdint>
#include <cuda.h>
#include <set>
#include <unordered_map>
#include <vector>

namespace tk = tensorrt_llm::kernels;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using ITensor = tensorrt_llm::runtime::ITensor;

namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
struct DiskAddress
{
    int fd;
    ssize_t pos;
};

using MemAddress = std::uintptr_t;

// Please make sure to align with the definition in tensorrt_llm/runtime/kv_cache_manager_v2/_common.py
constexpr tk::KVCacheIndex::UnderlyingType BAD_PAGE_INDEX = -1;

template <typename DstAddr, typename SrcAddr>
struct Task
{
    DstAddr dst;
    SrcAddr src;
};

using PackedInt = union
{
    int4 packed;
    tk::KVCacheIndex::UnderlyingType unpacked[4];
};

class IndexMapper
{
public:
    IndexMapper(SizeType32 maxBatchSize, SizeType32 maxBeamWidth);

    ~IndexMapper();

    IndexMapper(IndexMapper const&) = delete;
    IndexMapper& operator=(IndexMapper const&) = delete;

    SizeType32 addNewSequence(LlmRequest::RequestIdType requestId);

    SizeType32 getIndex(LlmRequest::RequestIdType requestId);

    void removeSequence(LlmRequest::RequestIdType requestId);

    at::Tensor getCopyIndex(
        std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 numContext, SizeType32 beamWidth);

private:
    std::unordered_map<LlmRequest::RequestIdType, SizeType32> indexMap_;
    std::set<SizeType32> freeIndices_;
    SizeType32 maxBeamWidth_;
    at::Tensor copyIndex_;
};
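
// Typical lifecycle (an illustrative sketch; the exact return-value semantics of
// addNewSequence are assumed from the free-index set above):
//
//   IndexMapper mapper(/*maxBatchSize=*/8, /*maxBeamWidth=*/1);
//   SizeType32 const slot = mapper.addNewSequence(/*requestId=*/42); // claims a free index
//   // mapper.getIndex(42) returns the same slot for the request's lifetime.
//   mapper.removeSequence(42); // releases the index back to the free set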

CUresult copyDiskToDisk(std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyDiskToHost(std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyHostToDisk(std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyHostToHost(std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyHostToDevice(
    std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyDeviceToHost(
    std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyDeviceToDevice(
    std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;

void copyBatchBlockOffsetsToDevice(ITensor const& input, ITensor& output, ITensor const& copyIndex,
    ITensor const& indexScales, ITensor const& kvOffset, CUstream stream) noexcept;

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

@@ -99,13 +99,15 @@ std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int
         }
     }
     if (!hasDraftTokens())
     {
-        result.contextPhaseParams = executor::ContextPhaseParams{
-            std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), std::nullopt};
+        result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId,
+            mContextPhaseParams.value().releaseState(), std::nullopt, mContextPhaseParams.value().getCtxDpRank(),
+            mContextPhaseParams.value().getDisaggInfoEndpoint()};
     }
     else
     {
-        result.contextPhaseParams = executor::ContextPhaseParams{
-            std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), *getDraftTokens()};
+        result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId,
+            mContextPhaseParams.value().releaseState(), *getDraftTokens(),
+            mContextPhaseParams.value().getCtxDpRank(), mContextPhaseParams.value().getDisaggInfoEndpoint()};
     }
 }

@@ -60,7 +60,8 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
 bool MLACacheFormatter::needSendCache(
     CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
 {
-    int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
+    int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
+    int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
 
     int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
         ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
@@ -356,8 +357,8 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
     auto& bufferManager = session.getBufferManager();
     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
     bool const recvSideHasCP = selfConfig.getParallelConfig().mContextParallelism > 1;
-    auto blockRange
-        = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse(), recvSideHasCP);
+    auto blockRange = getBlockRangeForReceiving(
+        mCacheManager, llmRequest, destConfig.getEnableBlockReuse(), destConfig.getEnablePartialReuse(), recvSideHasCP);
     auto const numPools = mCacheManager->getBlockManager().getNumPools(
         /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     auto const& windowSizes = blockRange.getWindowSizes();
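With the revised formula, CP ranks vary fastest within a TP group: for example, with mTensorParallelism = 2 and mContextParallelism = 2, selfIdx values 0 through 3 yield selfTpRank = (selfIdx % 4) / 2 = 0, 0, 1, 1, so CP peers of the same TP rank now compute the same selfTpRank. (That rank layout is implied by the formula itself, not stated elsewhere in the diff.)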
@@ -373,11 +373,11 @@ void PeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bo
     TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
 }
 
-std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>>
+std::tuple<std::unordered_map<uint64_t, std::future<void>>, BasePeftCacheManager::TaskIdToReqIds>
 PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVector const& generationRequests)
 {
-    std::map<uint64_t, std::vector<uint64_t>> taskIdToReqIds;
-    std::map<uint64_t, std::future<void>> taskIdToFuture;
+    TaskIdToReqIds taskIdToReqIds;
+    std::unordered_map<uint64_t, std::future<void>> taskIdToFuture;
     std::lock_guard<std::mutex> futuresLock(mPutFuturesMutex);
     for (auto const& requests : {contextRequests, generationRequests})
     {
@@ -415,7 +415,7 @@ PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVecto
     return {std::move(taskIdToFuture), taskIdToReqIds};
 }
 
-PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
+PeftCacheManager::EnsureBatchTaskResult PeftCacheManager::ensureBatchMapTaskId(
     RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
 {
     TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@@ -426,7 +426,7 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
     auto [taskIdToFuture_, taskIdToReqIds] = getTaskMaps(contextRequests, generationRequests);
     auto taskIdToFuture = std::move(taskIdToFuture_); // captured structured bindings are a C++20 extension
 
-    std::map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
+    std::unordered_map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
     for (auto& [taskId, taskFuture] : taskIdToFuture)
     {
         auto fn = [&taskIdToFuture, taskId = taskId, this]() -> std::vector<runtime::LoraCache::TaskLayerModuleConfig>
@@ -457,18 +457,31 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
         ensureFutures.try_emplace(taskId, std::move(f));
     }
 
-    PeftTable peftTable{};
+    TaskPeftTable peftTable{};
     for (auto const& [taskId, reqIds] : taskIdToReqIds)
     {
         auto&& f = ensureFutures.at(taskId);
         auto const values = f.get();
-        for (auto const& reqId : reqIds)
-        {
-            peftTable.try_emplace(reqId, values);
-        }
+        peftTable.try_emplace(taskId, values);
     }
     TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
-    return peftTable;
+    return {std::move(peftTable), std::move(taskIdToReqIds)};
+}
+
+PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
+    RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
+{
+    auto [taskTable, taskIdToReqIds] = ensureBatchMapTaskId(contextRequests, generationRequests, resetGpuCache);
+    PeftTable requestTable{};
+    for (auto const& [taskId, values] : taskTable)
+    {
+        auto const& reqIds = taskIdToReqIds.at(taskId);
+        for (auto const reqId : reqIds)
+        {
+            requestTable.try_emplace(reqId, values);
+        }
+    }
+    return requestTable;
 }
 
 bool PeftCacheManager::isTaskCached(uint64_t taskId) const
@@ -486,6 +499,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
     return mDeviceLoraCache->isDone(taskId);
 }
 
+bool PeftCacheManager::isTaskCachedDevice(uint64_t const taskId) const
+{
+    return mDeviceLoraCache->has(taskId);
+}
+
 void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
 {
     if (!terminate)
@@ -645,3 +663,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> l
     return 0;
 }
 } // namespace tensorrt_llm::batch_manager
+
+// TODO: merge C++ LoRA caching status with Py Slot manager
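As an example of the refactored flow above: if task 7 is shared by requests 1 and 2, ensureBatchMapTaskId resolves the ensure-future once and stores the resulting TaskLayerModuleConfig values under task key 7; ensureBatch then fans the same values out to request keys 1 and 2 via taskIdToReqIds, so each future is still consumed exactly once.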
@@ -17,8 +17,11 @@
 #include "tensorrt_llm/batch_manager/rnnStateManager.h"
 #include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/runtime/cudaStream.h"
 #include "tensorrt_llm/runtime/utils/runtimeUtils.h"
 
+#include <unordered_set>
+
 using namespace tensorrt_llm::runtime;
 
 namespace tensorrt_llm::batch_manager::rnn_state_manager
@@ -82,6 +85,64 @@ RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runti
     }
 }
 
+RnnStateManager::RnnStateManager(SizeType32 dState, SizeType32 dConv, SizeType32 numHeads, SizeType32 nGroups,
+    SizeType32 headDim, SizeType32 maxBatchSize, WorldConfig const& worldConfig, int64_t stream,
+    nvinfer1::DataType dtype, nvinfer1::DataType ssmCacheDtype, std::vector<SizeType32> const& ppLayers)
+    : mMaxNumSequences(maxBatchSize)
+    , mMaxBeamWidth{1}
+    , mBeamSlotsPerSequence{1}
+    , mBufferManager{std::make_shared<CudaStream>(reinterpret_cast<cudaStream_t>(stream))}
+{
+    auto const tpSize = worldConfig.getTensorParallelism();
+
+    auto const dInner = headDim * numHeads;
+    auto convDim = dInner + 2 * nGroups * dState;
+    auto nheads = numHeads;
+
+    TLLM_CHECK_WITH_INFO(nheads % tpSize == 0, "nheads must be divisible by tp_size");
+    TLLM_CHECK_WITH_INFO(convDim % tpSize == 0, "conv_dim must be divisible by tp_size");
+
+    convDim = convDim / tpSize;
+    nheads = nheads / tpSize;
+
+    auto const numLocalLayers = static_cast<SizeType32>(ppLayers.size());
+
+    for (SizeType32 offset = 0; offset < numLocalLayers; ++offset)
+    {
+        mLayerOffsets[ppLayers[offset]] = offset;
+    }
+
+    auto const convStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, convDim, dConv - 1});
+    pagedConvStates = mBufferManager->gpu(convStateShape, dtype);
+
+    auto const rnnStateShape = ITensor::makeShape({numLocalLayers, maxBatchSize, nheads, headDim, dState});
+    pagedRnnStates = mBufferManager->gpu(rnnStateShape, ssmCacheDtype);
+
+    mFreeBlocks.reserve(maxBatchSize);
+    for (SizeType32 i = 0; i < maxBatchSize; ++i)
+    {
+        mFreeBlocks.push_back(i);
+    }
+
+    auto const statePtrsShape = ITensor::makeShape({numLocalLayers});
+    rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
+    convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
+    auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
+    auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);
+
+    rnnStatePtr.resize(numLocalLayers);
+    convStatePtr.resize(numLocalLayers);
+    for (SizeType32 i = 0; i < numLocalLayers; i++)
+    {
+        auto layerRnnStates = ITensor::slice(pagedRnnStates, i, 1);
+        auto layerConvStates = ITensor::slice(pagedConvStates, i, 1);
+        rnnStatePtrArray[i] = layerRnnStates->data();
+        convStatePtrArray[i] = layerConvStates->data();
+        rnnStatePtr[i] = ITensor::slice(rnnStatePtrs, i, 1);
+        convStatePtr[i] = ITensor::slice(convStatePtrs, i, 1);
+    }
+}
+
 void RnnStateManager::getPtrBuffers(
     TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const
 {
@@ -113,4 +174,95 @@ void RnnStateManager::fillSlotMapping(
     }
 }
+
+void RnnStateManager::allocateCacheBlocks(std::vector<RequestIdType> const& requestIds)
+{
+    for (auto const& requestId : requestIds)
+    {
+        auto it = mCacheIndex.find(requestId);
+        if (it == mCacheIndex.end())
+        {
+            TLLM_CHECK_WITH_INFO(!mFreeBlocks.empty(), "Run out of RNN state cache blocks");
+            SizeType32 const block = mFreeBlocks.back();
+            mFreeBlocks.pop_back();
+            mCacheIndex[requestId] = block;
+        }
+    }
+}
+
+void RnnStateManager::freeCacheBlock(RequestIdType requestId)
+{
+    auto it = mCacheIndex.find(requestId);
+    if (it != mCacheIndex.end())
+    {
+        mFreeBlocks.push_back(it->second);
+        mCacheIndex.erase(it);
+    }
+}
+
+RnnStateManager::SizeType32 RnnStateManager::getCacheIndex(RequestIdType requestId) const
+{
+    auto it = mCacheIndex.find(requestId);
+    TLLM_CHECK_WITH_INFO(it != mCacheIndex.end(), "Request ID not found in cache index");
+    return it->second;
+}
+
+std::vector<RnnStateManager::SizeType32> RnnStateManager::getStateIndices(
+    std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding)
+{
+    TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size");
+
+    std::unordered_set<SizeType32> availableSlots;
+    availableSlots.reserve(mMaxNumSequences);
+    for (SizeType32 i = 0; i < mMaxNumSequences; ++i)
+    {
+        availableSlots.insert(i);
+    }
+
+    for (size_t i = 0; i < requestIds.size(); ++i)
+    {
+        if (!isPadding[i])
+        {
+            availableSlots.erase(getCacheIndex(requestIds[i]));
+        }
+    }
+
+    std::vector<SizeType32> result;
+    result.reserve(requestIds.size());
+    auto availableIt = availableSlots.begin();
+
+    for (size_t i = 0; i < requestIds.size(); ++i)
+    {
+        if (isPadding[i])
+        {
+            TLLM_CHECK_WITH_INFO(availableIt != availableSlots.end(), "Run out of available slots for padding");
+            result.push_back(*availableIt);
+            ++availableIt;
+        }
+        else
+        {
+            result.push_back(getCacheIndex(requestIds[i]));
+        }
+    }
+
+    return result;
+}
+
+RnnStateManager::TensorPtr RnnStateManager::getConvStates(SizeType32 layerIdx) const
+{
+    auto it = mLayerOffsets.find(layerIdx);
+    TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
+    auto result = ITensor::slice(pagedConvStates, it->second, 1);
+    result->squeeze(0);
+    return result;
+}
+
+RnnStateManager::TensorPtr RnnStateManager::getSsmStates(SizeType32 layerIdx) const
+{
+    auto it = mLayerOffsets.find(layerIdx);
+    TLLM_CHECK_WITH_INFO(it != mLayerOffsets.end(), "Layer index not found in layer offsets");
+    auto result = ITensor::slice(pagedRnnStates, it->second, 1);
+    result->squeeze(0);
+    return result;
+}
+
 } // namespace tensorrt_llm::batch_manager::rnn_state_manager
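To illustrate getStateIndices: with mMaxNumSequences = 4, requestIds = {A, B, C}, isPadding = {false, false, true}, and A and B holding cache blocks 3 and 0, the available set becomes {1, 2}, so the result is {3, 0, s} with s drawn from the unused slots; padding entries therefore never alias a live sequence's state.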
@@ -503,7 +503,7 @@ TrtGptModelInflightBatching::~TrtGptModelInflightBatching()
 {
     if (mCacheTransceiver)
     {
-        mCacheTransceiver->checkContextTransferStatus(true);
+        mCacheTransceiver->checkContextTransferStatus(1, true);
         TLLM_CHECK_WITH_INFO(mCacheTransceiver->checkGenTransferComplete(), "Generation transfer not complete");
     }
     if (mAsyncSendWaitThread)
@@ -932,7 +932,7 @@ void TrtGptModelInflightBatching::forwardSync()
     }
     if (mCacheTransceiver)
     {
-        mCacheTransceiver->checkContextTransferStatus(0);
+        mCacheTransceiver->checkContextTransferStatus(0, true);
     }
     ++mIterCounter;
 
@@ -1025,7 +1025,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
         mIterCounter);
     if (mCacheTransceiver)
    {
-        mCacheTransceiver->checkContextTransferStatus(1);
+        mCacheTransceiver->checkContextTransferStatus(1, true);
         // will free kvCache in next iteration.
     }
 }
@@ -36,6 +36,7 @@ add_library(common_src OBJECT ${SRCS} ${CU_SRCS})
 add_cuda_architectures(common_src 89)
 set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
+target_link_libraries(common_src PUBLIC trtllm_gen_fmha_interface)
 
 if(ENABLE_CUBLASLT_FP4_GEMM)
   target_compile_definitions(common_src PRIVATE ENABLE_CUBLASLT_FP4_GEMM)
@@ -296,7 +296,13 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     // Parameters for sparse attention
     xqaParams.sparse_params = mRuntimeSparseAttentionParams;
     xqaParams.use_sparse_attention = useTllmGenSparseAttention();
+    // Skip softmax threshold.
+    xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode;
+#ifdef SKIP_SOFTMAX_STAT
+    // Statistics of skip-softmax, pointers of device memory for output
+    xqaParams.skip_softmax_total_blocks = mSkipSoftmaxTotalBlocks;
+    xqaParams.skip_softmax_skipped_blocks = mSkipSoftmaxSkippedBlocks;
+#endif
     // Cross attention parameters.
     xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;

@@ -1035,7 +1041,7 @@ int AttentionOp::mlaGeneration(
     TllmGenFmhaRunnerParams tllmRunnerParams{};
 
     // Parameters to select kernels.
-    tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
+    tllmRunnerParams.mMaskType = TrtllmGenAttentionMaskType::Causal;
     tllmRunnerParams.mKernelType = FmhaKernelType::Generation;
     tllmRunnerParams.mMultiCtasKvMode = mMultiBlockMode;
     // Note that the tileScheduler and multiCtasKvMode will be automatically tuned when using multi_block mode.
@@ -1265,14 +1271,6 @@
         mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         return 0;
     }
-    else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
-    {
-        TLLM_CHECK_WITH_INFO(false, "No available XQA kernels are found for speculative decoding mode.");
-    }
-    else if (mFuseFp4Quant)
-    {
-        TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
-    }
     }
 
     // Use FMHA otherwise.
@@ -1313,6 +1311,8 @@
         fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
     }
 
+    // MLA does not support skip-softmax attention right now
+
     // Run the fmha kernel
     mDecoderFMHARunner->run(fmhaParams);
 }
@@ -1885,6 +1885,18 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
     }
+
+    // Skip-softmax attention parameters
+    fmhaParams.skipSoftmaxThresholdScaleFactor = mSkipSoftmaxThresholdScaleFactorPrefill;
+#ifdef SKIP_SOFTMAX_STAT
+    fmhaParams.skipSoftmaxTotalBlocks = mSkipSoftmaxTotalBlocks;
+    fmhaParams.skipSoftmaxSkippedBlocks = mSkipSoftmaxSkippedBlocks;
+#else
+    if (tensorrt_llm::common::getEnvPrintSkipSoftmaxStat())
+    {
+        TLLM_THROW("To print skip softmax stat, please run build_wheel.py with -DSKIP_SOFTMAX_STAT");
+    }
+#endif
 
     if (mAttentionChunkSize)
     {
         fmhaParams.chunkedAttentionSize = *mAttentionChunkSize;
@@ -127,7 +127,6 @@
 public:
     // Attention packed mask input (used by context FMHA).
     uint32_t const* attention_packed_mask = nullptr;
-    kernels::KVBlockArray::DataType* host_block_offsets = nullptr;
     int32_t batch_size = 0;
     float2 const* mrope_rotary_cos_sin = nullptr;

@@ -182,7 +181,6 @@
     ss << "context_buf_sf: " << this->context_buf_sf << std::endl;
     ss << "key_value_cache: " << (half*) this->key_value_cache << std::endl;
     ss << "block_offsets: " << this->block_offsets << std::endl;
-    ss << "host_block_offsets: " << this->host_block_offsets << std::endl;
     ss << "host_primary_pool_pointer: " << this->host_primary_pool_pointer << std::endl;
     ss << "host_secondary_pool_pointer: " << this->host_secondary_pool_pointer << std::endl;
     ss << "batch_size: " << this->batch_size << std::endl;
@@ -494,6 +492,14 @@
     // See [Chunked Attention] in _torch/modules/attention.py
     std::optional<int64_t> mAttentionChunkSize = std::nullopt;
 
+    // Skip softmax threshold scale factor.
+    float mSkipSoftmaxThresholdScaleFactorPrefill = 0;
+    float mSkipSoftmaxThresholdScaleFactorDecode = 0;
+#ifdef SKIP_SOFTMAX_STAT
+    uint32_t* mSkipSoftmaxTotalBlocks;
+    uint32_t* mSkipSoftmaxSkippedBlocks;
+#endif
+
     [[nodiscard]] auto data() const
     {
         return std::make_tuple(mLayerIdx, mNumHeads, mVisionStart, mVisionLength, mNumKVHeads, mHeadSize,
@@ -510,7 +516,8 @@
             mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
             mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
             mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
-            mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
+            mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1), mSkipSoftmaxThresholdScaleFactorPrefill,
+            mSkipSoftmaxThresholdScaleFactorDecode);
     };
 
 private:
@@ -43,7 +43,7 @@ template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_
 __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
 {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
 
     for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
@@ -63,7 +63,7 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
     }
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
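cudaGridDependencySynchronize() and cudaTriggerProgrammaticLaunchCompletion() are the CUDA runtime equivalents of the griddepcontrol PTX they replace, and both are effectively no-ops unless the kernel is launched with programmatic dependent launch (PDL) enabled. A minimal sketch of such a launch, assuming placeholder names (grid, block, stream, kernelPtr, and the kernel arguments are illustrative, not taken from this diff):

    cudaLaunchConfig_t config{};
    config.gridDim = grid;
    config.blockDim = block;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1; // opt this launch into PDL
    config.attrs = attrs;
    config.numAttrs = 1;
    // The kernel may begin before its predecessor on the stream fully completes;
    // cudaGridDependencySynchronize() inside the kernel then waits for the predecessor.
    cudaLaunchKernelEx(&config, kernelPtr, outPtr, scalePtr, inPtr, numel, lda);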
@@ -40,50 +40,12 @@ inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
     {
         return common::getEnvAllReduceWorkspaceSize();
     }
-    if (worldSize <= 2)
+    char const* envWorkspaceSize = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE");
+    if (envWorkspaceSize != nullptr)
     {
-        return 16 * 1000 * 1000;
+        return static_cast<size_t>(std::atoi(envWorkspaceSize));
     }
-    return 8 * 1000 * 1000;
-}
-
-// (SM major_version, TP_size) -> (NCCL_num_token_threshold, TWO_SHOT_numel_threshold)
-inline std::unordered_map<int, std::unordered_map<int, std::pair<size_t, size_t>>> HeuristicThresholdLP{
-    {90,
-        {
-            {2, {4096, 4096 * 4096}},
-            {4, {4096, 1024 * 1024}},
-            {8, {2048, 512 * 512}},
-        }},
-    {100,
-        {
-            {2, {4096, 4096 * 4096}},
-            {4, {4096, 1024 * 2048}},
-            {8, {4096, 1024 * 1024}},
-        }},
-};
-
-inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size, int world_size, AllReduceFusionOp op)
-{
-    // The heuristic is based on the following assumptions:
-    //  __________________________________
-    // |                \  TWO-SHOT zone  |
-    // |  ONE-SHOT zone  \                | NCCL zone
-    // |__________________\_______________|___
-    // sm_major is 90 or 100
-
-    auto const sm_major = std::min(100, std::max(90, tensorrt_llm::common::getSMVersion()));
-
-    auto const [nccl_num_token_threshold, two_shot_numel_threshold] = HeuristicThresholdLP[sm_major][world_size];
-    auto const message_size = seq_len * hidden_size;
-    if (message_size >= two_shot_numel_threshold)
-    {
-        return AllReduceStrategyType::TWOSHOT;
-    }
-    else
-    {
-        return AllReduceStrategyType::ONESHOT;
-    }
-}
+    return 67108864; // 64 MiB
 }
 
 // use 1D vector to store the best strategy instead of a map for each sm version
@@ -153,7 +115,7 @@ inline AllReduceStrategyType selectStrategyLookUpTable(
         || num_token_index
             >= AllReduceBestStrategyTable.at(sm_version).at(tp_index).at(fusion_op_index).at(hidden_size_index).size())
     {
-        return AllReduceStrategyType::NCCL_SYMMETRIC;
+        return AllReduceStrategyType::NCCL;
     }
 
     return static_cast<AllReduceStrategyType>(
@@ -164,20 +126,20 @@
 inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM90
     = {{ // TP=2
         {// Fusion=NONE
-            {4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 0, 0},
-            {4, 4, 5, 4, 4, 5, 5, 5, 4, 5, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 5, 4, 5, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 4, 5, 5, 5, 5, 4, 4, 5, 0, 0, 0}, {4, 5, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
-            {4, 4, 5, 5, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
-        {// Fusion=RESIDUAL_RMS_NORM
             {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
-            {4, 4, 4, 4, 5, 5, 5, 5, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 4, 5, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}},
+        {// Fusion=RESIDUAL_RMS_NORM
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
-            {4, 4, 4, 4, 4, 5, 5, 5, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 4, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
-            {4, 4, 4, 4, 5, 5, 5, 5, 5, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 5, 5, 5, 5, 4, 5, 4, 4, 4, 4, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 5, 5, 4, 4, 4, 5, 4, 4, 4, 0, 0, 0}, {4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0},
-            {4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
             {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
             {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
@@ -185,20 +147,20 @@ inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM90
             {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}},
        { // TP=4
         {// Fusion=NONE
-            {4, 4, 4, 4, 5, 4, 4, 5, 4, 5, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 4, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 5, 4, 5, 5, 5, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 4, 5, 4, 5, 5, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 5, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 5, 5, 5, 5, 4, 5, 5, 0, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0, 0},
-            {4, 4, 4, 4, 4, 5, 4, 5, 5, 0, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM
             {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 4, 5, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 4, 5, 4, 5, 5, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
             {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
             {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
@@ -206,20 +168,20 @@ inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM90
             {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}},
        { // TP=8
         {// Fusion=NONE
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0, 0, 0},
-            {4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0, 0, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM
             {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0},
             {4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0},
             {4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 0, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 5, 0, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 0, 0, 5, 0, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
             {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
             {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
@@ -229,67 +191,67 @@ inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM90
 inline AllReduceBestStrategyTableType AllReduceBestStrategyTableSM100
     = {{ // TP=2
         {// Fusion=NONE
-            {4, 4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0},
-            {4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0}, {4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM
-            {4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 5, 4, 4, 4, 5, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
-            {4, 4, 4, 4, 5, 4, 5, 5, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 5, 5, 4, 4, 4, 4, 4, 4, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
-            {4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 5, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0},
-            {4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0}}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}}},
        { // TP=4
         {// Fusion=NONE
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
-            {4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
+            {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0}},
         {// Fusion=RESIDUAL_RMS_NORM
|
{// Fusion=RESIDUAL_RMS_NORM
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0}},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0, 0}},
|
||||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0}},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0}},
|
||||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0}}},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0}}},
|
||||||
{ // TP=8
|
{ // TP=8
|
||||||
{// Fusion=NONE
|
{// Fusion=NONE
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0, 0, 0, 0}},
|
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}},
|
||||||
{// Fusion=RESIDUAL_RMS_NORM
|
{// Fusion=RESIDUAL_RMS_NORM
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
|
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}},
|
||||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
{// Fusion=RESIDUAL_RMS_NORM_QUANT_FP8
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}},
|
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}},
|
||||||
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
{// Fusion=RESIDUAL_RMS_NORM_QUANT_NVFP4
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 0, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0},
|
{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 0, 0}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0},
|
||||||
{4, 4, 4, 4, 4, 4, 4, 4, 5, 0, 0, 0, 0, 0, 0}}}};
|
{4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0}}}};
|
||||||
|
|
||||||
inline const std::unordered_map<int, AllReduceBestStrategyTableType> AllReduceBestStrategyTable = {
|
inline const std::unordered_map<int, AllReduceBestStrategyTableType> AllReduceBestStrategyTable = {
|
||||||
{90, AllReduceBestStrategyTableSM90},
|
{90, AllReduceBestStrategyTableSM90},
|
||||||
|
|||||||
@@ -249,7 +249,7 @@ bool getEnvUseTileSizeKv64ForTrtllmGen()
bool getEnvEnablePDL()
{
    static std::once_flag flag;
-    static bool enablePDL = false;
+    static bool enablePDL = true;

    std::call_once(flag,
        [&]()
@@ -257,7 +257,18 @@ bool getEnvEnablePDL()
            if (getSMVersion() >= 90)
            {
                // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
-                enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
+                char const* env = std::getenv("TRTLLM_ENABLE_PDL");
+                if (env)
+                {
+                    if (env[0] == '1' && env[1] == '\0')
+                    {
+                        enablePDL = true;
+                    }
+                    else if (env[0] == '0' && env[1] == '\0')
+                    {
+                        enablePDL = false;
+                    }
+                };
            }
        });
    return enablePDL;
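Note: this hunk both flips the default (PDL is now on by default once an SM90+ device is detected) and tightens parsing: only the exact strings "0" and "1" are honored, replacing the previous getBoolEnv call. A minimal standalone sketch of the new semantics (illustrative only; the real implementation lives in tensorrt_llm::common and caches the result via std::call_once):

#include <cstdlib>

bool parseEnablePdl(int smVersion)
{
    bool enablePDL = true; // new default (was false before this change)
    if (smVersion >= 90)
    {
        char const* env = std::getenv("TRTLLM_ENABLE_PDL");
        // Only the exact one-character strings "0" and "1" override the
        // default; unset, empty, or any other value keeps PDL enabled.
        if (env != nullptr)
        {
            if (env[0] == '1' && env[1] == '\0')
            {
                enablePDL = true;
            }
            else if (env[0] == '0' && env[1] == '\0')
            {
                enablePDL = false;
            }
        }
    }
    return enablePDL;
}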
@@ -281,6 +292,12 @@ bool getEnvUseNixlKvCache()
    return useNixlKvCache;
}

+bool getEnvUseMooncakeKvCache()
+{
+    static bool const useMooncakeKvCache = getBoolEnv("TRTLLM_USE_MOONCAKE_KVCACHE");
+    return useMooncakeKvCache;
+}
+
bool getEnvUseRoundRobinBlockDistForCP()
{
    static bool const useRoundRobinBlockDistForCP = getBoolEnv("TRTLLM_USE_ROUND_ROBIN_BLOCK_DIST_FOR_CP");
@@ -343,6 +360,23 @@ std::string getEnvNixlBackend()
    return nixlBackend;
}

+std::string getEnvMooncakeInterface()
+{
+    static std::once_flag flag;
+    static std::string mooncakeInterface;
+
+    std::call_once(flag,
+        [&]()
+        {
+            char const* mooncake_interface = std::getenv("TRTLLM_MOONCAKE_INTERFACE");
+            if (mooncake_interface)
+            {
+                mooncakeInterface = mooncake_interface;
+            }
+        });
+    return mooncakeInterface;
+}
+
bool getEnvDisaggLayerwise()
{
    static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
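Note: a hypothetical consumer-side sketch of the two new Mooncake accessors (the selection logic below is an assumption for illustration; only the accessor names and their env variables come from the hunks above):

#include <string>

namespace tensorrt_llm::common
{
bool getEnvUseMooncakeKvCache();       // reads TRTLLM_USE_MOONCAKE_KVCACHE
std::string getEnvMooncakeInterface(); // reads TRTLLM_MOONCAKE_INTERFACE
} // namespace tensorrt_llm::common

void configureKvCacheTransfer()
{
    if (tensorrt_llm::common::getEnvUseMooncakeKvCache())
    {
        // An empty string means no NIC was pinned via the env variable.
        std::string const nic = tensorrt_llm::common::getEnvMooncakeInterface();
        // ... construct a Mooncake-backed KV-cache transfer agent bound to `nic` ...
    }
}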
@@ -531,6 +565,11 @@ bool getEnvEplbForceGdrcopy()
    return getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY");
}

+bool getEnvPrintSkipSoftmaxStat()
+{
+    return getBoolEnv("TRTLLM_PRINT_SKIP_SOFTMAX_STAT");
+}
+
} // namespace common

TRTLLM_NAMESPACE_END
@@ -83,8 +83,11 @@ inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 g
bool getEnvUseUCXKvCache();

bool getEnvUseMPIKvCache();

bool getEnvUseNixlKvCache();

+bool getEnvUseMooncakeKvCache();
+
bool getEnvUseRoundRobinBlockDistForCP();

std::string getEnvUCXInterface();
@@ -93,6 +96,8 @@ std::string getEnvNixlInterface();

std::string getEnvNixlBackend();

+std::string getEnvMooncakeInterface();
+
bool getEnvDisaggLayerwise();

bool getEnvParallelCacheSend();
@@ -156,6 +161,8 @@ bool getEnvKVCacheTransferAllBlocksForWindow();

bool getEnvEplbForceGdrcopy();

+bool getEnvPrintSkipSoftmaxStat();
+
} // namespace common

TRTLLM_NAMESPACE_END
226  cpp/tensorrt_llm/common/ipUtils.cpp  Normal file
@@ -0,0 +1,226 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ipUtils.h"
#include "tensorrt_llm/common/logger.h"

#include <arpa/inet.h>
#include <cstring> // needed for std::strcmp / std::strncmp below
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
#include <unistd.h>

TRTLLM_NAMESPACE_BEGIN

namespace common
{

std::string getLocalIpByNic(std::string const& interface, int rank)
{
    struct ifaddrs* ifaddr = nullptr;
    if (getifaddrs(&ifaddr) == -1)
    {
        TLLM_LOG_ERROR(rank,
            "getLocalIpByNic: Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is "
            "set correctly.");
        return std::string{};
    }

    for (struct ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next)
    {
        if (ifa->ifa_addr == nullptr)
        {
            continue;
        }

        if (ifa->ifa_name == interface)
        {
            if (ifa->ifa_addr->sa_family == AF_INET)
            {
                char ip[INET_ADDRSTRLEN]{};
                void* addr = &((reinterpret_cast<struct sockaddr_in*>(ifa->ifa_addr))->sin_addr);
                if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "0.0.0.0") != 0)
                {
                    freeifaddrs(ifaddr);
                    return std::string(ip);
                }
            }
            else if (ifa->ifa_addr->sa_family == AF_INET6)
            {
                char ip[INET6_ADDRSTRLEN]{};
                void* addr = &((reinterpret_cast<struct sockaddr_in6*>(ifa->ifa_addr))->sin6_addr);
                if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
                    && std::strcmp(ip, "::1") != 0)
                {
                    freeifaddrs(ifaddr);
                    return std::string(ip);
                }
            }
        }
    }

    freeifaddrs(ifaddr);
    TLLM_LOG_ERROR(
        rank, "Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is set correctly.");
    return std::string{};
}

std::string getLocalIpByHostname(int rank)
{
    char hostname[256]{};
    if (gethostname(hostname, sizeof(hostname)) == -1)
    {
        TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
        return std::string{};
    }

    struct addrinfo hints = {};
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_CANONNAME;

    struct addrinfo* res = nullptr;
    if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
    {
        TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
        return std::string{};
    }

    for (struct addrinfo* p = res; p != nullptr; p = p->ai_next)
    {
        if (p->ai_family == AF_INET)
        { // IPv4
            char ip[INET_ADDRSTRLEN]{};
            struct sockaddr_in* ipv4 = reinterpret_cast<struct sockaddr_in*>(p->ai_addr);
            void* addr = &(ipv4->sin_addr);
            if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "127.0.0.1") != 0
                && std::strcmp(ip, "0.0.0.0") != 0)
            {
                freeaddrinfo(res);
                return std::string(ip);
            }
        }
        else if (p->ai_family == AF_INET6)
        { // IPv6
            char ip[INET6_ADDRSTRLEN]{};
            struct sockaddr_in6* ipv6 = reinterpret_cast<struct sockaddr_in6*>(p->ai_addr);
            void* addr = &(ipv6->sin6_addr);
            if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
                && std::strcmp(ip, "::1") != 0)
            {
                freeaddrinfo(res);
                return std::string(ip);
            }
        }
    }

    freeaddrinfo(res);
    TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
    return std::string{};
}

std::string getLocalIpByRemoteOrHostName(int rank)
{
    // Try IPv4
    struct sockaddr_in addr{};

    addr.sin_family = AF_INET;
    addr.sin_port = htons(80);
    // using google's public dns server to get the local ip which can be accessed from remote
    char const* dns_ip_v4 = "8.8.8.8";
    inet_pton(AF_INET, dns_ip_v4, &addr.sin_addr);

    int sock = socket(AF_INET, SOCK_DGRAM, 0);
    if (sock != -1)
    {
        if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != -1)
        {
            socklen_t addr_len = sizeof(addr);
            if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) != -1)
            {
                char ip[INET_ADDRSTRLEN]{};
                inet_ntop(AF_INET, &addr.sin_addr, ip, sizeof(ip));
                close(sock);
                return std::string(ip);
            }
        }
        close(sock);
    }

    // Try IPv6
    struct sockaddr_in6 addr6{};

    addr6.sin6_family = AF_INET6;
    addr6.sin6_port = htons(80);
    // using google's public dns server
    char const* dns_ipv6 = "2001:4860:4860::8888";
    inet_pton(AF_INET6, dns_ipv6, &addr6.sin6_addr);

    sock = socket(AF_INET6, SOCK_DGRAM, 0);
    if (sock != -1)
    {
        if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr6), sizeof(addr6)) != -1)
        {
            socklen_t addr_len = sizeof(addr6);
            if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr6), &addr_len) != -1)
            {
                char ip[INET6_ADDRSTRLEN]{};
                inet_ntop(AF_INET6, &addr6.sin6_addr, ip, sizeof(ip));
                close(sock);
                return std::string(ip);
            }
        }
        close(sock);
    }

    // Try hostname
    return getLocalIpByHostname(rank);
}

std::string getLocalIp(std::string interface, int rank)
{
    std::string localIP = {};
    if (!interface.empty())
    {
        localIP = getLocalIpByNic(interface, rank);
    }
    if (localIP.empty())
    {
        localIP = getLocalIpByRemoteOrHostName(rank);
    }
    // check whether the localIP is valid
    if (localIP.empty())
    {
        TLLM_THROW("getLocalIp: Can't get local ip");
    }
    return localIP;
}
} // namespace common

TRTLLM_NAMESPACE_END
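Note: getLocalIp tries three strategies in order: a NIC-pinned lookup via getifaddrs, the UDP-connect trick (connect() on a datagram socket selects a routable source address without sending any packet), and finally hostname resolution; it throws only when all three fail. A usage sketch (illustrative; rank is used only for log attribution):

#include <string>

namespace tensorrt_llm::common
{
std::string getLocalIp(std::string interface, int rank); // declared in ipUtils.h
} // namespace tensorrt_llm::common

std::string announceableAddress(int rank)
{
    // An empty interface skips getLocalIpByNic and falls through to
    // getLocalIpByRemoteOrHostName, then getLocalIpByHostname.
    return tensorrt_llm::common::getLocalIp(/*interface=*/"", rank);
}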
28  cpp/tensorrt_llm/common/ipUtils.h  Normal file
@@ -0,0 +1,28 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/config.h"
#include <string>

TRTLLM_NAMESPACE_BEGIN

namespace common
{
std::string getLocalIp(std::string interface, int rank);
} // namespace common

TRTLLM_NAMESPACE_END
@@ -37,6 +37,46 @@ NcclCommResourceManager& NcclCommResourceManager::getInstance() noexcept
    return instance;
}

+NcclCommResourceManager::~NcclCommResourceManager()
+{
+    // Mark that we're in destruction to prevent cleanup attempts from deleters
+    // that may run during static destruction
+    mIsDestroying.store(true, std::memory_order_release);
+
+    // Proactively clean up all resources before destruction
+    // This ensures cleanup happens in a controlled manner before static destruction
+    std::vector<std::pair<ncclComm_t, std::vector<ResourceEntry>>> allResources;
+
+    {
+        std::lock_guard<std::mutex> lock(mMutex);
+        // Move all resources out of the map
+        allResources.reserve(mCommResources.size());
+        for (auto& [comm, resources] : mCommResources)
+        {
+            allResources.emplace_back(comm, std::move(resources));
+        }
+        mCommResources.clear();
+    }
+
+    // Clean up all resources outside the lock
+    // Note: We don't call ncclCommDestroy here - that's the responsibility
+    // of the shared_ptr deleter. We just clean up registered resources.
+    for (auto& [comm, resources] : allResources)
+    {
+        for (auto& [cleanup, name] : resources)
+        {
+            try
+            {
+                cleanup();
+            }
+            catch (...)
+            {
+                // Ignore exceptions during destruction
+            }
+        }
+    }
+}
+
void NcclCommResourceManager::registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName)
{
    if (!comm)
@@ -60,23 +100,56 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
        return;
    }

+    // Check if we're in the process of being destroyed
+    // If so, skip cleanup - the destructor will handle it proactively
+    if (mIsDestroying.load(std::memory_order_acquire))
+    {
+        return;
+    }
+
    std::vector<ResourceEntry> resourcesToClean;

    {
-        std::lock_guard<std::mutex> lock(mMutex);
-        auto it = mCommResources.find(comm);
-        if (it == mCommResources.end())
-        {
-            // Nothing registered for this comm, nothing to clean up
-            return;
-        }
-
-        // Move resources out (preserves order) and remove from map
-        resourcesToClean = std::move(it->second);
-        mCommResources.erase(it);
-
-        TLLM_LOG_TRACE(
-            "[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(), static_cast<void*>(comm));
+        // During static destruction, mutex and logging may not be safe.
+        // Use try-catch to handle any issues gracefully.
+        try
+        {
+            std::lock_guard<std::mutex> lock(mMutex);
+
+            // Double-check after acquiring lock (destruction may have started)
+            if (mIsDestroying.load(std::memory_order_acquire))
+            {
+                return;
+            }
+
+            auto it = mCommResources.find(comm);
+            if (it == mCommResources.end())
+            {
+                // Nothing registered for this comm, nothing to clean up
+                return;
+            }
+
+            // Move resources out (preserves order) and remove from map
+            resourcesToClean = std::move(it->second);
+            mCommResources.erase(it);
+
+            // Logging may fail during static destruction, so wrap in try-catch
+            try
+            {
+                TLLM_LOG_TRACE("[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(),
+                    static_cast<void*>(comm));
+            }
+            catch (...)
+            {
+                // Ignore logging failures during static destruction
+            }
+        }
+        catch (...)
+        {
+            // If mutex access fails during static destruction, just return.
+            // This prevents segfaults when the singleton is being destroyed.
+            return;
+        }
    }

    // Clean up outside the lock to avoid deadlocks if cleanup functions try to access the manager
@@ -85,19 +158,41 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
    {
        try
        {
-            TLLM_LOG_TRACE(
-                "[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
+            // Logging may fail during static destruction, so wrap in try-catch
+            try
+            {
+                TLLM_LOG_TRACE(
+                    "[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
+            }
+            catch (...)
+            {
+                // Ignore logging failures during static destruction
+            }
            cleanup();
        }
        catch (std::exception const& e)
        {
-            TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s", name.c_str(),
-                static_cast<void*>(comm), e.what());
+            try
+            {
+                TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s",
+                    name.c_str(), static_cast<void*>(comm), e.what());
+            }
+            catch (...)
+            {
+                // Ignore logging failures during static destruction
+            }
        }
        catch (...)
        {
-            TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
-                name.c_str(), static_cast<void*>(comm));
+            try
+            {
+                TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
+                    name.c_str(), static_cast<void*>(comm));
+            }
+            catch (...)
+            {
+                // Ignore logging failures during static destruction
+            }
        }
    }
}
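Note: a sketch of how a per-communicator resource might be registered so that the new destructor/guard machinery covers it (registerResource and cleanupResources are from the hunks above; the cleanup body is an assumption):

#include <nccl.h>

void trackScratchBuffer(ncclComm_t comm, void* scratch)
{
    using tensorrt_llm::common::nccl_util::NcclCommResourceManager;
    NcclCommResourceManager::getInstance().registerResource(
        comm,
        [scratch]() { /* free or deregister `scratch` here */ },
        "scratch buffer");
    // cleanupResources(comm) later runs this functor outside the manager's
    // lock; if static destruction has already begun, mIsDestroying makes that
    // call a no-op and the destructor drains all entries itself.
}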
@@ -288,6 +383,21 @@ NCCLWindowBuffer NCCLWindowAllocator::requestBuffer(ncclComm_t comm, size_t size
        return bestFit->buffer;
    }

+    // No available buffer found, avoid registration during CUDA graph capture
+    auto stream = at::cuda::getCurrentCUDAStream();
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    auto capture_err = cudaStreamIsCapturing(stream, &capture_status);
+    if (capture_err != cudaSuccess)
+    {
+        TLLM_LOG_DEBUG("[NCCLUtil] cudaStreamIsCapturing failed: %s", cudaGetErrorString(capture_err));
+    }
+    if (capture_err == cudaSuccess && capture_status != cudaStreamCaptureStatusNone)
+    {
+        TLLM_LOG_DEBUG("[NCCLUtil] Skipping NCCL window allocation during capture for comm %p (requested: %zu)",
+            static_cast<void*>(comm), size);
+        return NCCLWindowBuffer();
+    }
+
    // No available buffer found, allocate a new one
    TLLM_LOG_TRACE(
        "[NCCLUtil] Allocating new NCCL window buffer for comm %p, size=%zu", static_cast<void*>(comm), size);
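Note: the early return above keeps window registration out of CUDA graph capture, where new allocations are unsafe to record. The guard generalizes; a minimal standalone form (the helper name below is an assumption):

#include <cuda_runtime_api.h>

// Returns true when `stream` is in the middle of CUDA graph capture, in which
// case side effects such as new registrations should be deferred or skipped.
bool isCapturing(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    cudaError_t const err = cudaStreamIsCapturing(stream, &status);
    // Like the diff above, treat a failed query as "not capturing".
    return err == cudaSuccess && status != cudaStreamCaptureStatusNone;
}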
@@ -516,8 +626,10 @@ void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept
    // Check for buffers still in use - this shouldn't happen if cleanup is called properly,
    // but we log a warning if it does
    size_t inUseCount = 0;
+    size_t totalBytes = 0;
    for (auto const& entry : commIt->second)
    {
+        totalBytes += entry.buffer.size;
        if (entry.inUse)
        {
            ++inUseCount;
@@ -530,6 +642,8 @@ void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept
            "This may indicate buffers weren't properly released before cleanup.",
            inUseCount, static_cast<void*>(comm));
    }
+    TLLM_LOG_DEBUG("[NCCLUtil] NCCL window allocator teardown for comm %p: %zu buffers, %zu bytes total",
+        static_cast<void*>(comm), commIt->second.size(), totalBytes);
+
    for (auto& entry : commIt->second)
    {
@@ -21,17 +21,19 @@
#include "tensorrt_llm/common/logger.h"

#if ENABLE_MULTI_DEVICE
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime_api.h>
#include <nccl.h>
#include <torch/extension.h>
#endif

#include <algorithm>
+#include <atomic>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
-#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
@@ -139,12 +141,13 @@ public:

private:
    NcclCommResourceManager() = default;
-    ~NcclCommResourceManager() = default;
+    ~NcclCommResourceManager();

    using ResourceEntry = std::pair<ResourceCleanupFunc, std::string>;

    mutable std::mutex mMutex;
    std::unordered_map<ncclComm_t, std::vector<ResourceEntry>> mCommResources;
+    std::atomic<bool> mIsDestroying{false};
};

// RAII helper to register a resource with a NCCL communicator.
@@ -375,15 +378,23 @@ inline std::pair<torch::Tensor, NCCLWindowBuffer> createNCCLWindowTensor(

    // Request buffer from allocator
    auto& allocator = NCCLWindowAllocator::getInstance();
-    auto buffer = allocator.requestBuffer(comm, buffer_size);
+    NCCLWindowBuffer buffer;
+
+    try
+    {
+        buffer = allocator.requestBuffer(comm, buffer_size);
+    }
+    catch (std::exception const& e)
+    {
+        TLLM_LOG_DEBUG("[createNCCLWindowTensor] requestBuffer failed; returning invalid buffer: %s", e.what());
+        return std::make_pair(torch::Tensor(), NCCLWindowBuffer());
+    }

    // Defensive validation: ensure buffer is valid before proceeding
    if (!buffer.isValid())
    {
-        std::ostringstream oss;
-        oss << "Failed to allocate NCCL window buffer: invalid buffer returned from requestBuffer "
-            << "(comm=" << static_cast<void*>(comm) << ", buffer_size=" << buffer_size << ")";
-        throw std::runtime_error(oss.str());
+        TLLM_LOG_DEBUG("[createNCCLWindowTensor] invalid buffer returned from requestBuffer; returning invalid buffer");
+        return std::make_pair(torch::Tensor(), NCCLWindowBuffer());
    }

    // Create custom deleter that releases the buffer
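Note: this changes the failure contract of createNCCLWindowTensor from throwing to returning an empty pair, so callers must now test validity. A schematic caller (the full argument list is elided in the hunk above; `comm` and `buffer_size` are assumed to be in scope, and the fallback branch is an assumption):

auto [tensor, window] = createNCCLWindowTensor(comm, buffer_size);
if (!window.isValid())
{
    // Window allocation was skipped (e.g. during CUDA graph capture) or
    // failed; fall back to a regular tensor and a non-window collective path.
    tensor = torch::empty({static_cast<int64_t>(buffer_size)}, torch::kUInt8);
}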
184  cpp/tensorrt_llm/common/nvmlWrapper.cpp  Normal file
@@ -0,0 +1,184 @@
/*
 * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <dlfcn.h>

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvmlWrapper.h"

#include <mutex>

TRTLLM_NAMESPACE_BEGIN

namespace common
{

std::shared_ptr<NVMLWrapper> NVMLWrapper::getInstance()
{
    static std::mutex mutex;
    static std::weak_ptr<NVMLWrapper> instance;
    std::shared_ptr<NVMLWrapper> result = instance.lock();
    if (result)
    {
        return result;
    }

    std::lock_guard<std::mutex> const lock(mutex);
    result = instance.lock();
    if (!result)
    {
        result = std::shared_ptr<NVMLWrapper>(new NVMLWrapper());
        instance = result;
    }
    return result;
}

NVMLWrapper::NVMLWrapper()
    : mHandle(dlopen("libnvidia-ml.so.1", RTLD_LAZY))
{
    TLLM_CHECK_WITH_INFO(mHandle != nullptr, "NVML library (libnvidia-ml.so.1) could not be loaded.");

    auto loadSym = [](void* handle, char const* name) -> void* { return dlsym(handle, name); };

    auto loadRequired = [&](void* handle, char const* name) -> void*
    {
        void* sym = loadSym(handle, name);
        TLLM_CHECK_WITH_INFO(sym != nullptr, "Required NVML symbol not found: %s", name);
        return sym;
    };

    *reinterpret_cast<void**>(&_nvmlInit) = loadRequired(mHandle, "nvmlInit_v2");
    *reinterpret_cast<void**>(&_nvmlShutdown) = loadRequired(mHandle, "nvmlShutdown");
    *reinterpret_cast<void**>(&_nvmlDeviceGetHandleByIndex) = loadRequired(mHandle, "nvmlDeviceGetHandleByIndex_v2");
    *reinterpret_cast<void**>(&_nvmlDeviceGetHandleByPciBusId)
        = loadRequired(mHandle, "nvmlDeviceGetHandleByPciBusId_v2");
    *reinterpret_cast<void**>(&_nvmlDeviceGetIndex) = loadRequired(mHandle, "nvmlDeviceGetIndex");
    *reinterpret_cast<void**>(&_nvmlDeviceGetNvLinkRemotePciInfo)
        = loadRequired(mHandle, "nvmlDeviceGetNvLinkRemotePciInfo_v2");
    *reinterpret_cast<void**>(&_nvmlDeviceGetNvLinkCapability) = loadRequired(mHandle, "nvmlDeviceGetNvLinkCapability");
    *reinterpret_cast<void**>(&_nvmlDeviceGetNvLinkState) = loadRequired(mHandle, "nvmlDeviceGetNvLinkState");
    *reinterpret_cast<void**>(&_nvmlErrorString) = loadRequired(mHandle, "nvmlErrorString");
    *reinterpret_cast<void**>(&_nvmlDeviceGetComputeRunningProcesses)
        = loadRequired(mHandle, "nvmlDeviceGetComputeRunningProcesses_v3");

    // Optional symbols - nullptr is OK (older drivers may not have these)
    *reinterpret_cast<void**>(&_nvmlDeviceGetGpuFabricInfoV) = loadSym(mHandle, "nvmlDeviceGetGpuFabricInfoV");
    *reinterpret_cast<void**>(&_nvmlDeviceGetGpuFabricInfo) = loadSym(mHandle, "nvmlDeviceGetGpuFabricInfo");

    if (!_nvmlDeviceGetGpuFabricInfoV)
    {
        TLLM_LOG_INFO(
            "NVML symbol nvmlDeviceGetGpuFabricInfoV not available (older driver). MNNVL fabric detection will use "
            "legacy API or be disabled.");
    }
    if (!_nvmlDeviceGetGpuFabricInfo)
    {
        TLLM_LOG_INFO("NVML symbol nvmlDeviceGetGpuFabricInfo not available.");
    }
}

NVMLWrapper::~NVMLWrapper()
{
    dlclose(mHandle);
}

nvmlReturn_t NVMLWrapper::nvmlInit() const
{
    return (*_nvmlInit)();
}

nvmlReturn_t NVMLWrapper::nvmlShutdown() const
{
    return (*_nvmlShutdown)();
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetHandleByIndex(unsigned int index, nvmlDevice_t* device) const
{
    return (*_nvmlDeviceGetHandleByIndex)(index, device);
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetHandleByPciBusId(char const* pciBusId, nvmlDevice_t* device) const
{
    return (*_nvmlDeviceGetHandleByPciBusId)(pciBusId, device);
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetIndex(nvmlDevice_t device, unsigned int* index) const
{
    return (*_nvmlDeviceGetIndex)(device, index);
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetNvLinkRemotePciInfo(
    nvmlDevice_t device, unsigned int link, nvmlPciInfo_t* pci) const
{
    return (*_nvmlDeviceGetNvLinkRemotePciInfo)(device, link, pci);
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetNvLinkCapability(
    nvmlDevice_t device, unsigned int link, nvmlNvLinkCapability_t capability, unsigned int* capResult) const
{
    return (*_nvmlDeviceGetNvLinkCapability)(device, link, capability, capResult);
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetNvLinkState(
    nvmlDevice_t device, unsigned int link, nvmlEnableState_t* isActive) const
{
    return (*_nvmlDeviceGetNvLinkState)(device, link, isActive);
}

char const* NVMLWrapper::nvmlErrorString(nvmlReturn_t result) const
{
    return (*_nvmlErrorString)(result);
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetGpuFabricInfoV(nvmlDevice_t device, nvmlGpuFabricInfoV_t* gpuFabricInfo) const
{
    if (!_nvmlDeviceGetGpuFabricInfoV)
    {
        return NVML_ERROR_FUNCTION_NOT_FOUND;
    }
    return (*_nvmlDeviceGetGpuFabricInfoV)(device, gpuFabricInfo);
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetGpuFabricInfo(nvmlDevice_t device, nvmlGpuFabricInfo_t* gpuFabricInfo) const
{
    if (!_nvmlDeviceGetGpuFabricInfo)
    {
        return NVML_ERROR_FUNCTION_NOT_FOUND;
    }
    return (*_nvmlDeviceGetGpuFabricInfo)(device, gpuFabricInfo);
}

nvmlReturn_t NVMLWrapper::nvmlDeviceGetComputeRunningProcesses(
    nvmlDevice_t device, unsigned int* infoCount, nvmlProcessInfo_v2_t* infos) const
{
    return (*_nvmlDeviceGetComputeRunningProcesses)(device, infoCount, infos);
}

bool NVMLWrapper::hasGpuFabricInfoV() const
{
    return _nvmlDeviceGetGpuFabricInfoV != nullptr;
}

bool NVMLWrapper::hasGpuFabricInfo() const
{
    return _nvmlDeviceGetGpuFabricInfo != nullptr;
}

} // namespace common

TRTLLM_NAMESPACE_END
123  cpp/tensorrt_llm/common/nvmlWrapper.h  Normal file
@@ -0,0 +1,123 @@
/*
 * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVML_WRAPPER_H
#define NVML_WRAPPER_H

#include "tensorrt_llm/common/config.h"

#include <nvml.h>

#include <memory>

TRTLLM_NAMESPACE_BEGIN

namespace common
{

class NVMLWrapper
{
public:
    static std::shared_ptr<NVMLWrapper> getInstance();

    ~NVMLWrapper();
    NVMLWrapper(NVMLWrapper const&) = delete;
    NVMLWrapper& operator=(NVMLWrapper const&) = delete;
    NVMLWrapper(NVMLWrapper&&) = delete;
    NVMLWrapper& operator=(NVMLWrapper&&) = delete;

    // Required NVML functions
    nvmlReturn_t nvmlInit() const;
    nvmlReturn_t nvmlShutdown() const;
    nvmlReturn_t nvmlDeviceGetHandleByIndex(unsigned int index, nvmlDevice_t* device) const;
    nvmlReturn_t nvmlDeviceGetHandleByPciBusId(char const* pciBusId, nvmlDevice_t* device) const;
    nvmlReturn_t nvmlDeviceGetIndex(nvmlDevice_t device, unsigned int* index) const;
    nvmlReturn_t nvmlDeviceGetNvLinkRemotePciInfo(nvmlDevice_t device, unsigned int link, nvmlPciInfo_t* pci) const;
    nvmlReturn_t nvmlDeviceGetNvLinkCapability(
        nvmlDevice_t device, unsigned int link, nvmlNvLinkCapability_t capability, unsigned int* capResult) const;
    nvmlReturn_t nvmlDeviceGetNvLinkState(nvmlDevice_t device, unsigned int link, nvmlEnableState_t* isActive) const;
    char const* nvmlErrorString(nvmlReturn_t result) const;
    nvmlReturn_t nvmlDeviceGetComputeRunningProcesses(
        nvmlDevice_t device, unsigned int* infoCount, nvmlProcessInfo_v2_t* infos) const;

    // Optional NVML functions (may be nullptr on older drivers)
    nvmlReturn_t nvmlDeviceGetGpuFabricInfoV(nvmlDevice_t device, nvmlGpuFabricInfoV_t* gpuFabricInfo) const;
    nvmlReturn_t nvmlDeviceGetGpuFabricInfo(nvmlDevice_t device, nvmlGpuFabricInfo_t* gpuFabricInfo) const;

    // Runtime availability checks
    bool hasGpuFabricInfoV() const;
    bool hasGpuFabricInfo() const;

private:
    void* mHandle;
    NVMLWrapper();

    // Required function pointers
    nvmlReturn_t (*_nvmlInit)();
    nvmlReturn_t (*_nvmlShutdown)();
    nvmlReturn_t (*_nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t*);
    nvmlReturn_t (*_nvmlDeviceGetHandleByPciBusId)(char const*, nvmlDevice_t*);
    nvmlReturn_t (*_nvmlDeviceGetIndex)(nvmlDevice_t, unsigned int*);
    nvmlReturn_t (*_nvmlDeviceGetNvLinkRemotePciInfo)(nvmlDevice_t, unsigned int, nvmlPciInfo_t*);
    nvmlReturn_t (*_nvmlDeviceGetNvLinkCapability)(nvmlDevice_t, unsigned int, nvmlNvLinkCapability_t, unsigned int*);
    nvmlReturn_t (*_nvmlDeviceGetNvLinkState)(nvmlDevice_t, unsigned int, nvmlEnableState_t*);
    char const* (*_nvmlErrorString)(nvmlReturn_t);
    nvmlReturn_t (*_nvmlDeviceGetComputeRunningProcesses)(nvmlDevice_t, unsigned int*, nvmlProcessInfo_v2_t*);

    // Optional function pointers (may be nullptr)
    nvmlReturn_t (*_nvmlDeviceGetGpuFabricInfoV)(nvmlDevice_t, nvmlGpuFabricInfoV_t*);
    nvmlReturn_t (*_nvmlDeviceGetGpuFabricInfo)(nvmlDevice_t, nvmlGpuFabricInfo_t*);
};

// RAII class that initializes NVML on construction and shuts down on destruction.
// Replaces duplicated NvmlManager classes in allreduceOp.cpp and allreducePlugin.cpp.
class NvmlManager
{
public:
    NvmlManager()
        : mNvml(NVMLWrapper::getInstance())
    {
        auto result = mNvml->nvmlInit();
        if (result != NVML_SUCCESS)
        {
            TLLM_THROW("Failed to initialize NVML: %s", mNvml->nvmlErrorString(result));
        }
    }

    ~NvmlManager()
    {
        mNvml->nvmlShutdown();
    }

    NVMLWrapper const& wrapper() const
    {
        return *mNvml;
    }

    std::shared_ptr<NVMLWrapper> const& sharedWrapper() const
    {
        return mNvml;
    }

private:
    std::shared_ptr<NVMLWrapper> mNvml;
};

} // namespace common

TRTLLM_NAMESPACE_END

#endif // NVML_WRAPPER_H
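Note: NVMLWrapper dlopens libnvidia-ml.so.1 and resolves symbols at runtime, so optional entry points must be probed before use; NvmlManager pairs nvmlInit/nvmlShutdown via RAII. A usage sketch (all names come from the header above; setting .version before the versioned fabric query follows standard NVML practice and is an assumption here):

#include "tensorrt_llm/common/nvmlWrapper.h"

void probeFabric()
{
    tensorrt_llm::common::NvmlManager nvml; // nvmlInit() now, nvmlShutdown() on scope exit

    nvmlDevice_t device{};
    auto const& api = nvml.wrapper();
    if (api.nvmlDeviceGetHandleByIndex(0, &device) != NVML_SUCCESS)
    {
        return;
    }
    if (api.hasGpuFabricInfoV()) // older drivers lack this symbol
    {
        nvmlGpuFabricInfoV_t info{};
        info.version = nvmlGpuFabricInfo_v2; // version field required by the versioned API
        (void) api.nvmlDeviceGetGpuFabricInfoV(device, &info);
    }
}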
@@ -123,13 +123,24 @@ std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
    if (*comm)
    {
        // Clean up all registered resources FIRST
+        // The cleanupResources function uses a destruction guard to safely handle
+        // static destruction order issues - it will return early if the singleton
+        // is being destroyed (in which case the destructor handles cleanup proactively)
        tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*comm);
+
        // Now destroy the NCCL communicator
        ncclResult_t result = ncclCommDestroy(*comm);
        if (result != ncclSuccess)
        {
-            TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
+            // Logging may fail during static destruction, so wrap in try-catch
+            try
+            {
+                TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
+            }
+            catch (...)
+            {
+                // Ignore logging failures during static destruction
+            }
        }

        // Clear the communicator value before freeing the pointer
@@ -38,6 +38,8 @@
#include <string>
#include <unordered_map>

+#include "tensorrt_llm/common/nvmlWrapper.h"
+
TRTLLM_NAMESPACE_BEGIN

namespace common::op
@@ -319,7 +321,8 @@ TRTLLM_NAMESPACE_END
        nvmlReturn_t r = cmd;                                                                                          \
        if (r != NVML_SUCCESS)                                                                                         \
        {                                                                                                              \
-            printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r));                        \
+            printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__,                                             \
+                tensorrt_llm::common::NVMLWrapper::getInstance()->nvmlErrorString(r));                                \
            exit(EXIT_FAILURE);                                                                                        \
        }                                                                                                              \
    } while (0)
@@ -330,6 +333,7 @@ TRTLLM_NAMESPACE_END
        nvmlReturn_t r = cmd;                                                                                          \
        if (TLLM_UNLIKELY(r != NVML_SUCCESS))                                                                          \
        {                                                                                                              \
-            TLLM_THROW("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r));                    \
+            TLLM_THROW("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__,                                         \
+                tensorrt_llm::common::NVMLWrapper::getInstance()->nvmlErrorString(r));                                \
        }                                                                                                              \
    } while (0)
@@ -46,7 +46,7 @@ CUTLASS_DEVICE
void launch_dependent_grids()
{
#if (defined(CUTLASS_GDC_ENABLED))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

@@ -57,7 +57,7 @@ CUTLASS_DEVICE
void wait_on_dependent_grids()
{
#if (defined(CUTLASS_GDC_ENABLED))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
#endif
}
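Note: the two hunks above drop inline PTX in favor of the CUDA device intrinsics for programmatic dependent launch (PDL), which emit the same griddepcontrol instructions while staying portable across compiler front ends. A toy kernel pair showing where each intrinsic sits (illustrative; real PDL launches also need the cudaLaunchAttributeProgrammaticStreamSerialization launch attribute on the host side):

__global__ void producerKernel(float* out)
{
    // ... main work that dependent grids rely on ...
    // Allow the next grid in the stream to begin launching early.
    cudaTriggerProgrammaticLaunchCompletion();
    // ... trailing work the consumer must NOT depend on ...
}

__global__ void consumerKernel(float const* in)
{
    // Cheap setup that does not read `in` can run before this point.
    cudaGridDependencySynchronize(); // block until producer grids complete
    // ... safe to consume `in` from here ...
}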
@ -686,4 +686,212 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class Collective>
|
||||||
|
struct MixedInputUtilsSM100
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
using KernelSchedule = typename Collective::KernelSchedule;
|
||||||
|
using ConversionMode = typename Collective::ConversionMode;
|
||||||
|
using SmemLayoutA = typename Collective::SmemLayoutA;
|
||||||
|
using SmemLayoutB = typename Collective::SmemLayoutB;
|
||||||
|
using ElementScale = typename Collective::ElementScale;
|
||||||
|
using ElementZero = typename Collective::ElementZero;
|
||||||
|
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Helper functions to select packing for conversion
|
||||||
|
template <class SrcType, class DstType, int Cosize>
|
||||||
|
struct select_packing
|
||||||
|
{ // Naive packing policy
|
||||||
|
|
||||||
|
static constexpr auto value()
|
||||||
|
{
|
||||||
|
return Int<cute::gcd(Cosize, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// (Designed for separate transform pipeline in Blackwell)
|
||||||
|
/// Utilities to dequantize A.
|
||||||
|
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
|
||||||
|
CUTLASS_DEVICE static void dequantize_A_kblock_for_transform(Tensor<EngineIn, LayoutIn> const& tArA,
|
||||||
|
Tensor<EngineOut, LayoutOut>& tArACompute, cute::tuple<Ts...> const& partitioned_extra_info, int const k_block)
|
||||||
|
{
|
||||||
|
|
||||||
|
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
|
||||||
|
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
|
||||||
|
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
|
||||||
|
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
|
||||||
|
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
|
||||||
|
using SrcType = typename EngineIn::value_type;
|
||||||
|
using DstType = typename EngineOut::value_type;
|
||||||
|
|
||||||
|
auto src = tArA(_, _, _, k_block);
|
||||||
|
auto dst = tArACompute(_, _, _, k_block);
|
||||||
|
auto pSrc = raw_pointer_cast(src.data());
|
||||||
|
auto pDst = const_cast<DstType*>(raw_pointer_cast(dst.data()));
|
||||||
|
constexpr int num_elements = decltype(size(src))::value;
|
||||||
|
|
||||||
|
constexpr int pack = decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
|
||||||
|
using Converter
|
||||||
|
= cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
using SrcArray = cutlass::Array<SrcType, pack>;
|
||||||
|
using DstArray = cutlass::Array<DstType, pack>;
|
||||||
|
constexpr int DstElementsPerReg = 32 / sizeof_bits_v<DstType>;
|
||||||
|
using RegArray = cutlass::AlignedArray<uint32_t, pack / DstElementsPerReg, sizeof(DstArray)>;
|
||||||
|
|
||||||
|
auto src_arr = recast<SrcArray>(src);
|
||||||
|
auto dst_arr = recast<DstArray>(dst);
|
||||||
|
|
||||||
|
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, pack));

        if constexpr (KernelConversionMode == ConversionMode::DirectConvert)
        {
            cute::transform(src_arr, dst_arr, Converter::convert);
        }
        else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
        {
            auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
            CUTE_STATIC_ASSERT_V(size(src) == size(scales));

            if constexpr (is_same_v<DstType, ElementScale>)
            {
                cute::transform(src_arr, dst_arr, Converter::convert);

                using ScaleArray = cutlass::Array<ElementScale, pack>;
                auto scale_arr = recast<ScaleArray>(filter_zeros(scales));

                if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
                {
                    Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
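
                    // Fast path: reinterpret each 32-bit register as __nv_bfloat162 and scale two
                    // bf16 lanes per __hmul2 instruction.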
                    for (int i = 0; i < size<1>(dst_vm); ++i)
                    {
                        auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
                        auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
                        CUTLASS_PRAGMA_UNROLL
                        for (size_t ii = 0; ii < RegArray::kElements; ++ii)
                        {
                            __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
                            bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
                        }
                    }
                }
                else
                {
                    cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
                }
            }
            else
            {
                constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
                constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
                constexpr int pack = cute::gcd(pack1, pack2);
                using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
                    cutlass::FloatRoundStyle::round_to_nearest>;
                using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
                    cutlass::FloatRoundStyle::round_to_nearest>;
                using SrcArray = cutlass::Array<SrcType, pack>;
                using DstArray = cutlass::Array<DstType, pack>;
                using StageArray = cutlass::Array<ElementScale, pack>;
                constexpr int iters = num_elements / pack;
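
                // Staged path for DstType != ElementScale: convert SrcType -> ElementScale, apply
                // the scale in ElementScale precision, then convert the staged result to DstType.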
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < iters; ++i)
                {
                    SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
                    DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
                    StageArray stageArr;
                    stageArr = Converter1::convert(*pSrcArr);
                    CUTLASS_PRAGMA_UNROLL
                    for (int j = 0; j < pack; ++j)
                    {
                        stageArr[j] = stageArr[j] * scales[i * pack + j];
                    }
                    *pDstArr = Converter2::convert(stageArr);
                }
            }
        }
        else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
        {
            static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");

            auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
            auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, _, k_block);
            CUTE_STATIC_ASSERT_V(size(src) == size(scales));
            CUTE_STATIC_ASSERT_V(size(src) == size(zeros));

            if constexpr (is_same_v<DstType, ElementZero>)
            {
                cute::transform(src_arr, dst_arr, Converter::convert);

                using ScaleArray = cutlass::Array<ElementScale, pack>;
                auto scale_arr = recast<ScaleArray>(filter_zeros(scales));

                using ZeroArray = cutlass::Array<ElementZero, pack>;
                auto zero_arr = recast<ZeroArray>(filter_zeros(zeros));

                if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
                {
                    Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
                    Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, pack));
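
                    // Same packed bf16x2 path as ConvertAndScale, with a fused zero-point add:
                    // two lanes at a time, d = d * scale (__hmul2) then d = d + zero (__hadd2).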
                    for (int i = 0; i < size<1>(dst_vm); ++i)
                    {
                        auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
                        auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
                        auto&& zero_reg = cute::recast<RegArray>(zeros_vm(_, i))(0);
                        CUTLASS_PRAGMA_UNROLL
                        for (size_t ii = 0; ii < RegArray::kElements; ++ii)
                        {
                            __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
                            bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
                            bf16x2_val = __hadd2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(zero_reg[ii]));
                        }
                    }
                }
                else
                {
                    cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
                    cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{});
                }
            }
            else
            {
                constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
                constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
                constexpr int pack = cute::gcd(pack1, pack2);
                using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
                    cutlass::FloatRoundStyle::round_to_nearest>;
                using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
                    cutlass::FloatRoundStyle::round_to_nearest>;
                using SrcArray = cutlass::Array<SrcType, pack>;
                using DstArray = cutlass::Array<DstType, pack>;
                using StageArray = cutlass::Array<ElementScale, pack>;
                constexpr int iters = num_elements / pack;
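
                // Staged path with zero-point: convert to ElementScale, compute
                // d = s * scale + zero in ElementScale precision, then convert to DstType.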
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < iters; ++i)
                {
                    SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
                    DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
                    StageArray stageArr;
                    stageArr = Converter1::convert(*pSrcArr);
                    CUTLASS_PRAGMA_UNROLL
                    for (int j = 0; j < pack; ++j)
                    {
                        stageArr[j] = stageArr[j] * scales[i * pack + j] + zeros[i * pack + j];
                    }
                    *pDstArr = Converter2::convert(stageArr);
                }
            }
        }
        else
        {
            static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                "Conversion mode not handled for input partitioning.");
        }
    }
};

} // namespace cutlass::gemm::collective::detail

@@ -0,0 +1,294 @@
/*
 * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cutlass/gemm/collective/builders/sm100_common.inl"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail
{

// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
    class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
    int stages>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(StageCount<stages> stage_count)
{
    constexpr int Load2TransformStageCount = stages;
    constexpr int Transform2MmaStageCount = stages;
    constexpr int AccumulatorStageCount = stages;
    return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}

template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
    class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
    int carveout_bytes>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(
    StageCountAutoCarveout<carveout_bytes> stage_count)
{
    constexpr int CtaM = get<0>(CtaTileShape_MNK{});
    constexpr int CtaN = get<1>(CtaTileShape_MNK{});
    static_assert(CtaN <= 128, "Can't support CtaN > 128 tiles");
    constexpr int CtaK = get<2>(CtaTileShape_MNK{});
    using AtomThrID = typename TiledMma::AtomThrID;

    constexpr int TmemColumns = 512;
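    // SM100 TMEM provides 512 columns per CTA; accumulator stages and (optionally) staged-A
    // buffers must share this budget.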

    constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K
        && !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>;
    constexpr bool IsAComputeinSmem = !IsAComputeinTmem;

    // Detect 2x2 TMEM layout
    constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN / 2 : CtaN;
    constexpr int TmemAWordsPerDP = CtaK / 2;

    constexpr int AccumulatorStageCount
        = (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP);
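    // When A-compute also lives in TMEM, cap the accumulator at 2-3 stages to leave columns for
    // staged A; otherwise the accumulator may use the full TMEM column budget.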

    constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32);

    constexpr int TmemInAStageCount_Potential
        = (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000;

    // Load2Transform pipeline
    constexpr auto load2transform_pipeline_bytes
        = sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage);
    constexpr auto a_bits = cute::sizeof_bits_v<ElementA>; // ElementA is introduced here
    constexpr auto s_bits = cute::is_void_v<ElementScale> ? 0 : cute::sizeof_bits_v<ElementScale>;
    constexpr auto z_bits = cute::is_void_v<ElementZero> ? 0 : cute::sizeof_bits_v<ElementZero>;

    constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage);
    constexpr auto b_bits = cute::sizeof_bits_v<ElementB>; // ElementB is introduced here

    constexpr int ab_stage_bytes
        = cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
        + cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
        + cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
        + cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{}))
        + static_cast<int>(load2transform_pipeline_bytes) + static_cast<int>(load2mma_pipeline_bytes);
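    // Per-stage SMEM cost: one A tile, its scale and zero-point tiles (one value per
    // ScaleGranularityK elements along K), one B tile, plus both pipelines' barrier storage.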

    // Transform2Mma pipeline
    constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage);
    constexpr auto a_compute_bits = cute::sizeof_bits_v<ElementAMma>;
    constexpr int ab_compute_stage_bytes = cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem)
        * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
        + // If ACompute is in TMEM, the ACompute buffer takes 0 bytes.
        static_cast<int>(transform2mma_pipeline_bytes);

    constexpr int ABComputeStageCount_Potential
        = SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes);

    // The number of SMEM buffers for A and B. ACompute (if in SMEM) and BCompute should be at least Transform2MmaStageCount.
    constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential);

    constexpr int SmemCapacityAfterABComputeCarveout
        = SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes);

    // Can we boost the number of buffers for A and B?
    constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes;

    static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2,
        "Not enough SMEM or TMEM capacity for the selected tile size");
    return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}

} // namespace detail

// Mixed-input MMA kernels builder
template <class ElementAOptionalTuple, class GmemLayoutATagTuple, int AlignmentA, class ElementBOptionalTuple,
    class GmemLayoutBTag, int AlignmentB, class ElementAccumulator,
    class TileShape_MNK, // The cluster-level tile shape
    class ClusterShape_MNK, class StageCountType, class KernelScheduleType>
struct CollectiveBuilderSm100WeightOnly<arch::Sm100, arch::OpClassTensorOp,
    ElementAOptionalTuple, // ElementA
    GmemLayoutATagTuple,   // LayoutA
    AlignmentA,
    ElementBOptionalTuple, // ElementB
    GmemLayoutBTag,        // LayoutB
    AlignmentB, ElementAccumulator,
    TileShape_MNK,    // (MmaAtomShapeM, MmaAtomShapeN, TileK)
    ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int)
    StageCountType, KernelScheduleType,
    cute::enable_if_t<(cute::is_base_of_v<KernelScheduleSm100MixedInputGemm, KernelScheduleType>)
        && ((sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0)
        && ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>>
{
    using GmemLayoutATag = detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>;
    using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>;

    static constexpr cute::UMMA::Major UmmaMajorA
        = cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
    static constexpr cute::UMMA::Major UmmaMajorB
        = cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();

    using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>;
    using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>;
    using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>;
    using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>;

    static constexpr bool NeitherIsTuple
        = !cute::is_tuple<ElementAOptionalTuple>::value && !cute::is_tuple<ElementBOptionalTuple>::value;
    static constexpr bool IsANarrow = cute::sizeof_bits_v<ElementA> < cute::sizeof_bits_v<ElementB>;
    static constexpr bool IsMixedInput = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
    static_assert(IsMixedInput, "The mixed-input GEMM kernel doesn't support regular (same-width) GEMM.");

    static_assert(
        (cute::is_tuple<ElementAOptionalTuple>::value ^ cute::is_tuple<ElementBOptionalTuple>::value
            || (NeitherIsTuple && (cute::sizeof_bits<ElementA>::value != cute::sizeof_bits<ElementB>::value))),
        "Either A OR B must be a tuple, or the widths of A and B must be different.");
    using ElementPairA = cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>,
        ElementAOptionalTuple>;
    using ElementPairB = cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>,
        ElementBOptionalTuple>;
    static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
    static_assert(IsATransformed, "The A matrix must be the transformed operand.");

    // For fp32 types, map to tf32 MMA value type.
    using ElementMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

    using ElementAMma = ElementMma;
    using ElementBMma = ElementMma;

    static constexpr bool IsSubbyteA = cute::sizeof_bits_v<ElementA> < 8;
    using TmaElementA = cute::conditional_t<IsSubbyteA, uint8_t, ElementA>;

    static constexpr int ScalingFactor = 1;

    using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma<ElementAMma, ElementB,
        ElementAccumulator, TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB, KernelScheduleType>());
    using AtomThrID = typename TiledMma::AtomThrID;
    using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
    using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{}));
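
    // The CTA-level tile is the cluster-level tile divided by the MMA atom's CTA extent, so a
    // 2-SM MMA atom halves the per-CTA tile along M.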

    // ((MMA_TILE_M, MMA_TILE_K), MMA_M, MMA_K)
    using MmaShapeA_MK = decltype(partition_shape_A(
        TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
    // ((MMA_TILE_N, MMA_TILE_K), MMA_N, MMA_K)
    using MmaShapeB_NK = decltype(partition_shape_B(
        TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));

    using BlockTileA_M = decltype(cute::size<0, 0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
    using BlockTileA_K = decltype(cute::size<0, 1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));

    using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{})));
    using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}));

    // The input-transform kernel cannot use TMA 2SM instructions.
    using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA, ElementA,
        BlockTileA_M, BlockTileA_K>());
    using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA,
        ElementAMma, BlockTileA_M, BlockTileA_K>());
    using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomA,
        SmemLayoutAtomACompute>;
    static constexpr int MMA_M = cute::size<0, 0>(MmaShapeA_MK{});
    using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>,
        cute::conditional_t<
            (UmmaMajorA == cute::UMMA::Major::K
                && !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>),
            cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x,
                SM100_TMEM_STORE_32dp32b8x>,                                   // TS implementation
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>> // SS implementation
        >;

    using BlockTileB_N = decltype(cute::size<0, 0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
    using BlockTileB_K = decltype(cute::size<0, 1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));

    // The input-transform kernel cannot use TMA 2SM instructions.
    using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB, ElementB,
        BlockTileB_N, BlockTileB_K>());
    using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB,
        ElementBMma, BlockTileB_N, BlockTileB_K>());
    using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomB,
        SmemLayoutAtomBCompute>;
    using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementB>,
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementMma>>;

    // Create the stride of the transformed input.
    using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
    using LayoutScale = cutlass::gemm::TagToStrideA_t<GmemLayoutScaleTag>;

    using VoidShapeScale
        = Shape<Shape<Int<128>, _1>, Shape<Int<64>, _1>, _1>; // Dummy value used to create a dummy ScaleConfig
    using VoidStrideScale = Stride<Stride<_0, _1>, Stride<_0, _1>, _1>;
    using VoidLayoutScale = Layout<VoidShapeScale, VoidStrideScale>;

    using NonVoidLayoutScale = cute::conditional_t<cute::is_void_v<LayoutScale>, VoidLayoutScale, LayoutScale>;

    using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{}));

    // SmemCarveout
    static constexpr int SchedulerPipelineStageCount = 3;
    static constexpr bool IsArrayOfPointersGemm
        = (cute::is_base_of_v<KernelScheduleSm100PtrArrayFastFP32Gemm, KernelScheduleType>);

    // CLCPipeline = PipelineCLCFetchAsync
    static constexpr auto CLCPipelineStorage
        = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
    // CLC (scheduler) response
    static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
    // CLC throttle pipeline storage
    static constexpr auto CLCThrottlePipelineStorage
        = sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
    // Tmem dealloc
    static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
    // Tmem ptr storage
    static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t);
    // Tensormap storage
    static constexpr size_t TensorMapStorage
        = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;

    // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
    static constexpr auto KernelSmemCarveout = static_cast<int>(CLCPipelineStorage + CLCResponseStorage
        + CLCThrottlePipelineStorage + TmemDeallocStorage + TmemBasePtrsStorage + TensorMapStorage);

    // Reduce the SMEM capacity available for mainloop buffers by the scheduler/TMEM bookkeeping carveout.
    static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout;

    static constexpr int ScaleGranularityK = get_ScaleGranularityK<LayoutScale>();

    static constexpr auto stage_info
        = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_weightonly<
            Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB,
            CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{});

    static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info);
    static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info);
    static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info);

    static_assert(!IsArrayOfPointersGemm, "Mixed input does not support grouped GEMM on Blackwell.");

    using DispatchPolicy
        = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput<Load2TransformPipelineStageCount,
            Transform2MmaPipelineStageCount, SchedulerPipelineStageCount, AccumulatorPipelineStageCount,
            ClusterShape_MNK>;
    using CollectiveOp = cutlass::gemm::collective::CollectiveMmaSm100WeightOnly<DispatchPolicy, TileShape_MNK,
        ElementPairA, StridePairA, ElementPairB, cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>, TiledMma,
        GmemTiledCopyA, SmemLayoutAtomPairA, CopyAtomPairA, cute::identity, GmemTiledCopyB, SmemLayoutAtomPairB,
        CopyAtomPairB, cute::identity>;
};
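
// A minimal usage sketch. The element, layout, alignment, and schedule choices below are
// illustrative assumptions (ScaleLayoutTag is a placeholder for the scale-factor layout tag),
// not the only supported configuration:
//
//   using Collective = typename cutlass::gemm::collective::CollectiveBuilderSm100WeightOnly<
//       cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
//       /*ElementA pair*/ cute::tuple<cutlass::int4b_t, cutlass::bfloat16_t>,
//       /*LayoutA pair*/  cute::tuple<cutlass::layout::RowMajor, ScaleLayoutTag>, /*AlignmentA*/ 32,
//       /*ElementB*/ cutlass::bfloat16_t, /*LayoutB*/ cutlass::layout::ColumnMajor, /*AlignmentB*/ 8,
//       /*Accumulator*/ float,
//       /*TileShape*/ cute::Shape<cute::_128, cute::_128, cute::_128>,
//       /*ClusterShape*/ cute::Shape<cute::_1, cute::_1, cute::_1>,
//       cutlass::gemm::collective::StageCountAutoCarveout<0>,
//       KernelScheduleSm100MixedInputGemm>::CollectiveOp;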

} // namespace cutlass::gemm::collective

@@ -0,0 +1,42 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp"

namespace cutlass::gemm::collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
    class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
    class StageCountType, class KernelScheduleType, class Enable = void>
struct CollectiveBuilderSm100WeightOnly
{
    static_assert(sizeof(ElementA) == 0, "Could not build a collective for the given parameters.");
};
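
// The unspecialized primary template always fails with a readable diagnostic; supported
// configurations are provided by the specializations in the .inl included below.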

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,42 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass/detail/dependent_false.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
    class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
    class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB>
struct CollectiveMmaSm100WeightOnly
{
    static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////