Commit Graph

2 Commits

Author SHA1 Message Date
liji-nv
e07fff4f78
[https://nvbugs/5340941] - fix: Correct custom ops used by Qwen3 Moe … (#6285)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-07-25 14:49:45 +08:00
Bo Li
9ae705af1b
perf: Add fused q_norm/k_norm/RoPE for Qwen3. (#4482)
* Add Julien's origina kernel.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Get rid of UpdateKVCache functionality.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add kernels.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add torch OP.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Update cmake.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Torch OP must use double as argument dtype.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add unittest.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add unittest.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Fix misaligned access when head_dim=64.
In this case, numElemsPerThread=2, numVecPerThread=0. But the store code incorrectly perform vectorized store, some threads (e.g., lane1) issue store to address that is not aligned to 64 bit.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Remove unroll (compiler can do that).
Cleanup code.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add switch for interleave.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Refactor vectorized load/store.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Implement is_neox. Result not correct yet.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Fix is_neox=True.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add q_weight and k_weight.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

---------

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
2025-05-23 15:31:04 +08:00