mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Cherry-pick DeepGEMM related commits from release/1.1.0rc2 (#7716)
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
This commit is contained in:
parent
28469dbf27
commit
4f0e6b5f96
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -28,4 +28,5 @@
|
||||
url = https://github.com/zeromq/cppzmq.git
|
||||
[submodule "3rdparty/DeepGEMM"]
|
||||
path = 3rdparty/DeepGEMM
|
||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||
url = https://github.com/ruoqianguo/DeepGEMM.git
|
||||
branch = swapab_sm100
|
||||
|
||||
2
3rdparty/DeepGEMM
vendored
2
3rdparty/DeepGEMM
vendored
@ -1 +1 @@
|
||||
Subproject commit 89b4089d24216c107f8f805d931a817abb241850
|
||||
Subproject commit 67e3c4d3d09b59405fd6e7698a33db747ed96533
|
||||
@ -949,8 +949,9 @@ class fp8SwapABGemmRunner(TunableRunner):
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
# Encode swap_ab as False (0) and True (1). Currently only add one tactic here.
|
||||
return [0]
|
||||
# Encode swap_ab as False (0) and True (1). Currently enabled when GEMM m <= 128.
|
||||
input, _, _ = inputs
|
||||
return [0, 1] if input.shape[0] <= 128 else [0]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -964,9 +965,9 @@ class fp8SwapABGemmRunner(TunableRunner):
|
||||
device=input.device,
|
||||
dtype=self.output_dtype,
|
||||
)
|
||||
# TODO: add swap_ab=tactic == 0 to detemrmine the swap_ab value
|
||||
# Treat the default tactic=-1 as swap_ab=False
|
||||
deep_gemm.fp8_gemm_nt(
|
||||
|
||||
forward_func = deep_gemm.fp8_gemm_ntt if tactic == 1 else deep_gemm.fp8_gemm_nt
|
||||
forward_func(
|
||||
(a, a_sf),
|
||||
(weight, weight_scale),
|
||||
output,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user