# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00).
# Python source: 45 lines, 1.6 KiB.
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout.py
# Triton is licensed under the MIT License.
|
|
|
from .layout_details.base import Layout
|
|
from .layout_details.blackwell_scale import BlackwellMXScaleLayout
|
|
from .layout_details.blackwell_value import BlackwellMXValueLayout
|
|
from .layout_details.hopper_scale import HopperMXScaleLayout
|
|
from .layout_details.hopper_value import HopperMXValueLayout
|
|
from .layout_details.cdna4_scale import CDNA4MXScaleLayout
|
|
from .layout_details.strided import StridedLayout
|
|
from ..target_info import cuda_capability_geq, is_hip_cdna4
|
|
|
|
# Public names re-exported by this module: the layout base class, the
# architecture-specific MX value/scale layouts, and the strided layout that
# the factory helpers return when no specialised layout applies.
__all__ = [
    "Layout",
    "BlackwellMXValueLayout",
    "BlackwellMXScaleLayout",
    "HopperMXScaleLayout",
    "HopperMXValueLayout",
    "CDNA4MXScaleLayout",
    "StridedLayout",
]
|
|
|
|
|
|
def make_default_matmul_mxfp4_w_layout(mx_axis: int):
    """Choose the default layout for MXFP4 matmul weight *values*.

    Selection, in priority order:

    * ``cuda_capability_geq(10)`` is true -> ``BlackwellMXValueLayout``, no kwargs;
    * ``cuda_capability_geq(9)`` is true  -> ``HopperMXValueLayout`` with ``mx_axis``;
    * otherwise                           -> ``StridedLayout``, no kwargs.

    Args:
        mx_axis: Forwarded as the ``mx_axis`` keyword only on the Hopper
            path; ignored by the other branches.

    Returns:
        A ``(layout_class, layout_kwargs)`` tuple. ``layout_kwargs`` is a
        freshly built dict on every call.
    """
    if cuda_capability_geq(10):
        chosen_cls, chosen_kwargs = BlackwellMXValueLayout, {}
    elif cuda_capability_geq(9):
        chosen_cls, chosen_kwargs = HopperMXValueLayout, {"mx_axis": mx_axis}
    else:
        chosen_cls, chosen_kwargs = StridedLayout, {}
    return chosen_cls, chosen_kwargs
|
|
|
|
|
|
def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
    """Choose the default layout for MXFP4 matmul weight *scales*.

    Selection, in priority order:

    * ``is_hip_cdna4()`` is true          -> ``CDNA4MXScaleLayout``, no kwargs;
    * ``cuda_capability_geq(10)`` is true -> ``BlackwellMXScaleLayout``, no kwargs;
    * ``cuda_capability_geq(9)`` is true  -> ``HopperMXScaleLayout`` with
      ``mx_axis`` and ``num_warps``;
    * otherwise                           -> ``StridedLayout``, no kwargs.

    Args:
        mx_axis: Forwarded as ``mx_axis`` only on the Hopper path.
        num_warps: Forwarded as ``num_warps`` only on the Hopper path.

    Returns:
        A ``(layout_class, layout_kwargs)`` tuple. ``layout_kwargs`` is a
        freshly built dict on every call.
    """
    # The HIP/CDNA4 check runs first; the CUDA capability queries are only
    # reached when it is false.
    if is_hip_cdna4():
        return CDNA4MXScaleLayout, {}
    if cuda_capability_geq(10):
        return BlackwellMXScaleLayout, {}
    if cuda_capability_geq(9):
        hopper_kwargs = {"mx_axis": mx_axis, "num_warps": num_warps}
        return HopperMXScaleLayout, hopper_kwargs
    # No specialised scale layout matched.
    return StridedLayout, {}
|