TensorRT-LLMs/triton_kernels/numerics.py
Anish Shanbhag 24ac86c485
[https://nvbugs/5761391][fix] Include triton-kernels as a packaged dependency (#10471)
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
2026-01-28 19:56:32 -08:00

47 lines
1.2 KiB
Python

# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/numerics.py
# Triton is licensed under the MIT License.
import torch
from dataclasses import dataclass
# Largest finite magnitudes for three 8-bit floating-point encodings, written
# as (largest usable significand) * 2**(largest usable exponent). Each product
# is exact in binary floating point, so the values equal the plain literals
# shown in the trailing comments.
# NOTE(review): which hardware/torch dtype each name maps to is inferred from
# the constant names only — confirm against the Triton fp8 dtype definitions.
MAX_FINITE_FLOAT8E5 = 1.75 * 2.0**15  # == 57344.0
MAX_FINITE_FLOAT8E4NV = 1.75 * 2.0**8  # == 448.0
MAX_FINITE_FLOAT8E4B8 = 1.875 * 2.0**7  # == 240.0
@dataclass(frozen=True)
class BaseFlexData:
dtype: torch.dtype | None = None
def view(self, x: torch.Tensor):
if self.dtype is None:
return x
return x.view(self.dtype)
def reinterpret(self, x):
if self.dtype is None or x.dtype.itemsize > 1:
return x
return x.view(self.dtype)
@dataclass(frozen=True)
class InFlexData(BaseFlexData):
    """``BaseFlexData`` extended with an optional ``scale`` tensor."""

    # Optional scale tensor; ``None`` means no scale was supplied.
    scale: torch.Tensor | None = None

    @property
    def is_per_batch(self):
        """Whether ``scale`` is set and ``len(scale)`` exceeds one.

        Returns ``False`` when no scale is present.
        """
        if self.scale is None:
            return False
        # len() of a tensor is the size of its first dimension.
        return len(self.scale) > 1
@dataclass(frozen=True)
class OutFlexData(BaseFlexData):
    """``BaseFlexData`` extended with three optional output-side scale tensors.

    Instances are iterable, which permits unpacking such as
    ``expected, actual, checksum = out_flex``.
    """

    expected_scale: torch.Tensor | None = None
    actual_scale: torch.Tensor | None = None
    checksum_scale: torch.Tensor | None = None

    def __iter__(self):
        """Yield ``expected_scale``, ``actual_scale``, ``checksum_scale`` in that order."""
        yield from (self.expected_scale, self.actual_scale, self.checksum_scale)