TensorRT-LLMs/triton_kernels/tensor_details/layout.py

# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout.py
# Triton is licensed under the MIT License.
from .layout_details.base import Layout
from .layout_details.blackwell_scale import BlackwellMXScaleLayout
from .layout_details.blackwell_value import BlackwellMXValueLayout
from .layout_details.hopper_scale import HopperMXScaleLayout
from .layout_details.hopper_value import HopperMXValueLayout
from .layout_details.cdna4_scale import CDNA4MXScaleLayout
from .layout_details.strided import StridedLayout
from ..target_info import cuda_capability_geq, is_hip_cdna4

__all__ = [
    "Layout",
    "BlackwellMXValueLayout",
    "BlackwellMXScaleLayout",
    "HopperMXScaleLayout",
    "HopperMXValueLayout",
    "CDNA4MXScaleLayout",
    "StridedLayout",
]


def make_default_matmul_mxfp4_w_layout(mx_axis: int):
    # Select the default MXFP4 weight (value) layout for the current GPU.
    if cuda_capability_geq(10):
        # return StridedLayout, dict()
        return BlackwellMXValueLayout, dict()
    elif cuda_capability_geq(9):
        return HopperMXValueLayout, {"mx_axis": mx_axis}
    else:
        return StridedLayout, dict()
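
# Editor's note: a minimal usage sketch, not part of the vendored upstream file.
# The variable names below are illustrative; only the return contract (a Layout
# subclass plus a dict of options for it) comes from the function above:
#
#   value_layout, value_layout_opts = make_default_matmul_mxfp4_w_layout(mx_axis=1)
#   # SM100+ (Blackwell): (BlackwellMXValueLayout, {})
#   # SM90 (Hopper):      (HopperMXValueLayout, {"mx_axis": 1})
#   # anything older:     (StridedLayout, {})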


def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
    # Select the default MXFP4 scale layout: AMD CDNA4 first, then NVIDIA
    # by compute capability, falling back to a plain strided layout.
    if is_hip_cdna4():
        return CDNA4MXScaleLayout, dict()
    else:
        if cuda_capability_geq(10):
            return BlackwellMXScaleLayout, dict()
        elif cuda_capability_geq(9):
            return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}

    return StridedLayout, dict()
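
# Editor's note: a companion sketch for the scale layout, likewise illustrative
# and not part of the vendored upstream file:
#
#   scale_layout, scale_layout_opts = make_default_matmul_mxfp4_w_scale_layout(
#       mx_axis=1, num_warps=8)
#   # AMD CDNA4:          (CDNA4MXScaleLayout, {})
#   # SM100+ (Blackwell): (BlackwellMXScaleLayout, {})
#   # SM90 (Hopper):      (HopperMXScaleLayout, {"mx_axis": 1, "num_warps": 8})
#   # otherwise:          (StridedLayout, {})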