Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-08 20:21:48 +08:00)
Latest commits (all Signed-off-by: Zihua Wu):

* feat: use NVRTC for DeepGEMM JIT compilation
* fix: add license
* feat: store NVRTC JIT results in memory by default
* feat: refinement
* feat: refinement
* test: set timeout to 7200
194 lines
7.6 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import List, Tuple

import pytest
import torch
from _torch.helpers import calc_diff, ceil_div, per_block_cast_to_fp8
from utils.util import getSMVersion


@pytest.mark.skipif(
    getSMVersion() != 90,
    reason="The test is for Hopper only. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
    "k, n",
    [(7168, 2112), (2048, 7168)],
)
@pytest.mark.parametrize(
    "m",
    [7, 64, 128, 4096],
)
def test_fp8_block_scale_gemm(m, k, n):
    torch.random.manual_seed(0)
    # Scale inputs by 1/k so the bf16 reference accumulation stays well-conditioned.
    a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) / k
    b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) / k

    # Activations get 1x128 per-tile scales; weights get 128x128 per-block scales.
    act_a_fp8, act_a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)
    act_b_fp8, act_b_sf = per_block_cast_to_fp8(b)

    output = torch.ops.trtllm.fp8_block_scaling_gemm(act_a_fp8, act_b_fp8,
                                                     act_a_sf, act_b_sf)

    output_expected = a @ b.t()
    diff = calc_diff(output, output_expected)
    assert diff < 1e-3
    torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
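

# For intuition, a pure-PyTorch sketch of what a 1x128 per-tile FP8 quantizer
# could look like. This is an assumption about the semantics of
# torch.ops.trtllm.fp8_quantize_1x128 (per-tile scale = amax / 448, the e4m3
# finite max), not a description of the real op, whose scale layout differs.
def _ref_quantize_1x128_sketch(
        a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    m, k = a.shape
    assert k % 128 == 0
    tiles = a.float().view(m, k // 128, 128)
    # One scale per 1x128 tile, chosen so the largest value maps to the
    # e4m3 finite max (448); clamp avoids division by zero for empty tiles.
    amax = tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = amax / 448.0
    fp8 = (tiles / scale).to(torch.float8_e4m3fn)
    return fp8.view(m, k), scale.squeeze(-1)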


def change_to_offset_layout(
    ms: List[int],
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    x_list = []
    x_scale_list = []
    num_problems = len(ms)
    m_acc = [0] + list(itertools.accumulate(ms))

    for i in range(num_problems):
        x_list.append(x_fp8[m_acc[i]:m_acc[i + 1]])
        x_scale_padded = x_scale[m_acc[i]:m_acc[i + 1]]
        # Zero-pad each group's scale rows up to a multiple of 32.
        if x_scale_padded.shape[0] % 32 != 0:
            x_empty = torch.zeros(
                [32 - (x_scale_padded.shape[0] % 32), x_scale_padded.shape[1]],
                dtype=x_scale_padded.dtype,
                device=x_scale_padded.device,
            )
            x_scale_padded = torch.cat([x_scale_padded, x_empty])
        x_scale_list.append(x_scale_padded)

    shape_m_total = m_acc[-1]
    ret_x = torch.cat(x_list)
    ret_x_scale = torch.cat(x_scale_list)

    ret_x_scale = ret_x_scale.t().contiguous()
    # After the transpose, pad the column dimension out to the layout's
    # worst-case 32-aligned width.
    pad_target = ceil_div(shape_m_total + num_problems * 31, 32) * 32
    pad_target -= ret_x_scale.shape[1]
    ret_x_scale = torch.nn.functional.pad(ret_x_scale, (0, pad_target),
                                          mode='constant',
                                          value=0)

    return ret_x, ret_x_scale
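

# Worked example of the padding arithmetic above, for the parametrized case
# ms = [16, 24, 48]: per-group scale rows pad to 32, 32 and 64 (128 columns
# after the transpose), while the final width is
# ceil_div(88 + 3 * 31, 32) * 32 = 192, so 64 extra zero columns are appended.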


def construct_grouped(
    ms: List[int], k: int, n: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor]:
    assert all(m % 4 == 0 for m in ms), f'TMA alignment error: {ms}'
    torch.random.manual_seed(0)
    num_groups = len(ms)
    x = torch.randn((sum(ms), k), device='cuda', dtype=torch.bfloat16) / k
    y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) / k
    m_acc = [0] + list(itertools.accumulate(ms))
    # Reference: each row group of x multiplies its own weight matrix y[i].
    ref_out = torch.empty((sum(ms), n), device='cuda', dtype=torch.bfloat16)
    for i in range(num_groups):
        ref_out[m_acc[i]:m_acc[i + 1]] = torch.einsum('mk,nk->mn',
                                                      x[m_acc[i]:m_acc[i + 1]],
                                                      y[i])

    x_fp8, x_scale = (torch.empty_like(x, dtype=torch.float8_e4m3fn),
                      torch.empty((sum(ms), k // 128),
                                  device='cuda',
                                  dtype=torch.float))
    y_fp8, y_scale = (torch.empty_like(y, dtype=torch.float8_e4m3fn),
                      torch.empty((num_groups, (n + 127) // 128, k // 128),
                                  device='cuda',
                                  dtype=torch.float))

    for i in range(num_groups):
        xi = x[m_acc[i]:m_acc[i + 1]]
        yi = y[i]
        x_fp8_i, x_scale_i = torch.ops.trtllm.fp8_quantize_1x128(xi)
        x_fp8[m_acc[i]:m_acc[i + 1]] = x_fp8_i.view(
            x_fp8[m_acc[i]:m_acc[i + 1]].shape)
        # View the returned scales with reversed dims, then transpose back so
        # they land row-major in the preallocated buffer.
        x_scale[m_acc[i]:m_acc[i + 1]] = x_scale_i.view(
            x_scale[m_acc[i]:m_acc[i + 1]].shape[::-1]).t().contiguous()

        y_fp8_i, y_scale_i = per_block_cast_to_fp8(yi)
        y_fp8[i] = y_fp8_i.view(y_fp8[i].shape)
        y_scale[i] = y_scale_i.view(y_scale[i].shape)

    return x_fp8, x_scale, y_fp8, y_scale, ref_out
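

# Shape check for the grouped path with ms = [256, 256], k = 7168, n = 4096:
# x_fp8 is (512, 7168), x_scale is (512, 56) since k // 128 = 56, and y_scale
# is (2, 32, 56) since (4096 + 127) // 128 = 32.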


def construct_batched(
    num_batches: int, m: int, k: int, n: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor]:
    assert m % 4 == 0, f'TMA alignment error: {m}'

    torch.random.manual_seed(0)
    x = torch.randn(
        (num_batches, m, k), device='cuda', dtype=torch.bfloat16) / k
    y = torch.randn(
        (num_batches, n, k), device='cuda', dtype=torch.bfloat16) / k
    ref_out = torch.einsum('bmk,bnk->bmn', x, y)

    x_fp8, x_scale = (torch.empty_like(x, dtype=torch.float8_e4m3fn),
                      torch.empty((num_batches, m, k // 128),
                                  device='cuda',
                                  dtype=torch.float))
    y_fp8, y_scale = (torch.empty_like(y, dtype=torch.float8_e4m3fn),
                      torch.empty((num_batches, (n + 127) // 128, k // 128),
                                  device='cuda',
                                  dtype=torch.float))

    # Quantize each batch: 1x128 tiles for x, 128x128 blocks for y.
    for i in range(num_batches):
        x_fp8[i], x_scale_i = torch.ops.trtllm.fp8_quantize_1x128(x[i])
        x_scale[i] = x_scale_i.view(x_scale[i].shape)
        y_fp8[i], y_scale[i] = per_block_cast_to_fp8(y[i])

    return x_fp8, x_scale, y_fp8, y_scale, ref_out


@pytest.mark.skipif(
    getSMVersion() != 90,
    reason="Op only supported on Hopper, current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize("ms", [[256, 256], [128, 64, 64], [16, 24, 48]])
@pytest.mark.parametrize("k, n", [(7168, 4096), (2048, 7168)])
def test_fp8_block_scaling_moe_gemm(ms, k, n):
    # Cumulative row offsets mark where each group starts in the packed layout.
    offset_cpu = [0] + list(itertools.accumulate(ms))
    offset = torch.tensor(offset_cpu, device='cuda', dtype=torch.int64)
    x_fp8, x_scale, y_fp8, y_scale, ref_out = construct_grouped(ms, k, n)
    x_fp8, x_scale = change_to_offset_layout(ms, x_fp8, x_scale)
    out = torch.ops.trtllm.fp8_block_scaling_moe_gemm(x_fp8, y_fp8, x_scale,
                                                      y_scale, offset)
    diff = calc_diff(out, ref_out)
    assert diff < 1e-3
    torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-3)


@pytest.mark.skipif(
    getSMVersion() != 90,
    reason="Op only supported on Hopper, current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize("batch_size, m", [(1, 1024), (2, 512), (4, 256)])
@pytest.mark.parametrize("k, n", [(7168, 4096), (2048, 7168)])
def test_fp8_block_scaling_bmm(batch_size, m, k, n):
    torch.random.manual_seed(0)
    x_fp8, x_scale, y_fp8, y_scale, ref_out = construct_batched(
        batch_size, m, k, n)
    output = torch.ops.trtllm.fp8_block_scaling_bmm(x_fp8, y_fp8, x_scale,
                                                    y_scale)
    diff = calc_diff(output, ref_out)
    assert diff < 1e-3
    torch.testing.assert_close(output, ref_out, atol=1e-3, rtol=1e-3)
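

# A minimal direct-run sketch covering one parametrization of each test. This
# block is an illustrative addition: the canonical entry point is pytest
# (e.g. `pytest -k fp8_block_scal`), and it assumes a Hopper (SM90) GPU with
# the trtllm custom ops registered.
if __name__ == '__main__':
    if getSMVersion() == 90:
        test_fp8_block_scale_gemm(m=128, k=7168, n=2112)
        test_fp8_block_scaling_moe_gemm(ms=[256, 256], k=7168, n=4096)
        test_fp8_block_scaling_bmm(batch_size=2, m=512, k=7168, n=4096)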