mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[None][feat] Cherry-pick DeepGEMM related commits from release/1.1.0rc2 (#7716)
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
This commit is contained in:
parent
28469dbf27
commit
4f0e6b5f96
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -28,4 +28,5 @@
|
|||||||
url = https://github.com/zeromq/cppzmq.git
|
url = https://github.com/zeromq/cppzmq.git
|
||||||
[submodule "3rdparty/DeepGEMM"]
|
[submodule "3rdparty/DeepGEMM"]
|
||||||
path = 3rdparty/DeepGEMM
|
path = 3rdparty/DeepGEMM
|
||||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
url = https://github.com/ruoqianguo/DeepGEMM.git
|
||||||
|
branch = swapab_sm100
|
||||||
|
|||||||
2
3rdparty/DeepGEMM
vendored
2
3rdparty/DeepGEMM
vendored
@ -1 +1 @@
|
|||||||
Subproject commit 89b4089d24216c107f8f805d931a817abb241850
|
Subproject commit 67e3c4d3d09b59405fd6e7698a33db747ed96533
|
||||||
@ -949,8 +949,9 @@ class fp8SwapABGemmRunner(TunableRunner):
|
|||||||
inputs: List[torch.Tensor],
|
inputs: List[torch.Tensor],
|
||||||
profile: OptimizationProfile,
|
profile: OptimizationProfile,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
# Encode swap_ab as False (0) and True (1). Currently only add one tactic here.
|
# Encode swap_ab as False (0) and True (1). Currently enabled when GEMM m <= 128.
|
||||||
return [0]
|
input, _, _ = inputs
|
||||||
|
return [0, 1] if input.shape[0] <= 128 else [0]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -964,9 +965,9 @@ class fp8SwapABGemmRunner(TunableRunner):
|
|||||||
device=input.device,
|
device=input.device,
|
||||||
dtype=self.output_dtype,
|
dtype=self.output_dtype,
|
||||||
)
|
)
|
||||||
# TODO: add swap_ab=tactic == 0 to determine the swap_ab value
|
|
||||||
# Treat the default tactic=-1 as swap_ab=False
|
forward_func = deep_gemm.fp8_gemm_ntt if tactic == 1 else deep_gemm.fp8_gemm_nt
|
||||||
deep_gemm.fp8_gemm_nt(
|
forward_func(
|
||||||
(a, a_sf),
|
(a, a_sf),
|
||||||
(weight, weight_scale),
|
(weight, weight_scale),
|
||||||
output,
|
output,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user