TensorRT-LLMs/tensorrt_llm/quantization/utils/fp4_utils.py
Dan Blanaru 16d2467ea8 Update TensorRT-LLM (#2755)
* Update TensorRT-LLM

---------

Co-authored-by: Denis Kayshev <topenkoff@gmail.com>
Co-authored-by: akhoroshev <arthoroshev@gmail.com>
Co-authored-by: Patrick Reiter Horn <patrick.horn@gmail.com>

Update
2025-02-11 03:01:00 +00:00

23 lines
589 B
Python

import torch
# The declarations must be aligned with thUtils.h
# Both constants are bound to torch.uint8, i.e. one byte per tensor element.
# NOTE(review): "SF" presumably stands for "scale factor" -- confirm against thUtils.h.
SF_DTYPE: torch.dtype = torch.uint8
# NOTE(review): the name suggests two E2M1 FP4 values packed into each byte -- confirm against thUtils.h.
FLOAT4_E2M1X2: torch.dtype = torch.uint8
def pad_up(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y``.

    Args:
        x: Value to round up.
        y: Alignment to round to; must be a positive integer.

    Returns:
        The smallest multiple of ``y`` that is greater than or equal to ``x``.

    Raises:
        ValueError: If ``y`` is zero or negative. A zero alignment would
            otherwise surface as a bare ``ZeroDivisionError`` and a negative
            one would silently return a value that is not a padded size.
    """
    if y <= 0:
        raise ValueError(f"pad_up requires a positive alignment, got y={y}")
    # Ceiling division expressed with floor division, then scaled back up to
    # a multiple of y.
    return ((x + y - 1) // y) * y
# Bucket values for GEMM autotuning.
# Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime/modelConfig.h#L38
# NOTE(review): what each value measures is not visible in this file -- see the header linked above.
# TODO: move to model config, tune for blackwell hardware
FP4_BUCKETS = [64, 128, 256, 512, 1024]
# Export
# Lower-case aliases bound to the same objects as the upper-case constants.
float4_e2m1x2 = FLOAT4_E2M1X2
float4_sf_dtype = SF_DTYPE
fp4_buckets = FP4_BUCKETS
# Only the lower-case aliases and pad_up are listed here, so a star-import of
# this module does not bring in the upper-case constant names.
__all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']