TensorRT-LLMs/tests/_torch/test_scaled_mm.py
Sharan Chetlur 258c7540c0 open source 09df54c0cc99354a60bbc0303e3e8ea33a96bef0 (#2725)
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

open source f8c0381a2bc50ee2739c3d8c2be481b31e5f00bd (#2736)

Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

Add note for blackwell (#2742)

Update the docs to workaround the extra-index-url issue (#2744)

update README.md (#2751)

Fix github io pages (#2761)

Update
2025-02-11 02:21:51 +00:00

88 lines
2.6 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from warnings import warn

import numpy as np
import pytest
import torch

# The parent directory (which contains utils/) must be on sys.path BEFORE
# importing from utils.util, otherwise the import can only succeed by accident
# (e.g. when pytest is launched from a directory that already provides it).
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion
@pytest.mark.skipif(
    getSMVersion() != 90,
    reason="custom scaled_mm is only supported in SM90",
)  # Skip tests that are not supported in SM90
@pytest.mark.parametrize(
    "k_n",
    [(8192, 10240), (8192, 8192), (8192, 57344), (28672, 8192)],
)
@pytest.mark.parametrize(
    "m",
    [2048, 12, 228],
)
@pytest.mark.parametrize(
    "output_dtype",
    [torch.float16, torch.float32, torch.bfloat16],
)
def test_fp8_scaled_mm(output_dtype, m, k_n):
    """Compare the custom cuBLAS and CUTLASS FP8 scaled-mm ops against
    torch._scaled_mm on a grid of GEMM shapes and output dtypes.

    The cuBLAS path must match the reference exactly; known CUTLASS
    accuracy issues on some shapes are downgraded to a warning.
    """
    k, n = k_n
    torch.random.manual_seed(0)
    # Activations are (m, k); weights are stored (n, k) and transposed
    # into the GEMM, matching the usual row-major weight layout.
    x = torch.rand((m, k), device="cuda").to(torch.float8_e4m3fn)
    w = torch.rand((n, k), device="cuda").to(torch.float8_e4m3fn)
    scale_x = torch.rand(1, device="cuda")
    scale_w = torch.rand(1, device="cuda")

    cublas_out = torch.ops.trtllm.cublas_scaled_mm(x,
                                                   w.t(),
                                                   scale_x,
                                                   scale_w,
                                                   bias=None,
                                                   out_dtype=output_dtype)
    cutlass_out = torch.ops.trtllm.cutlass_scaled_mm(x,
                                                     w.t(),
                                                     scale_x,
                                                     scale_w,
                                                     bias=None,
                                                     out_dtype=output_dtype)
    ref = torch._scaled_mm(x,
                           w.t(),
                           out_dtype=output_dtype,
                           scale_a=scale_x,
                           scale_b=scale_w,
                           use_fast_accum=True)

    # cuBLAS result must agree with the reference bit-for-bit (default
    # assert_allclose tolerances are essentially exact-match here).
    np.testing.assert_allclose(ref.float().cpu(), cublas_out.float().cpu())
    # TODO(zhenhuan): cutlass kernel has acc issue on some shapes
    try:
        np.testing.assert_allclose(ref.float().cpu(),
                                   cutlass_out.float().cpu())
    except Exception as e:
        warn(RuntimeWarning("cutlass result is not correct: " + repr(e)))
if __name__ == '__main__':
    # Quick smoke run of a single small configuration when executed directly.
    test_fp8_scaled_mm(output_dtype=torch.float16, m=12, k_n=(8192, 10240))