import torch

# The declarations must be aligned with thUtils.h

SF_DTYPE = torch.uint8

FLOAT4_E2M1X2 = torch.uint8


def pad_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y
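
# Illustrative examples (a sketch, not taken from the source): pad_up rounds
# x up to the next multiple of y, e.g. to align a dimension to an FP4 block
# or bucket size:
#   pad_up(16, 16) == 16
#   pad_up(17, 16) == 32
#   pad_up(7169, 128) == 7296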

# For GEMM autotuning.
# Taken from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/include/tensorrt_llm/runtime//modelConfig.h#L38
# TODO: move to model config, tune for Blackwell hardware
FP4_BUCKETS = [64, 128, 256, 512, 1024]

# Export
float4_e2m1x2 = FLOAT4_E2M1X2
float4_sf_dtype = SF_DTYPE
fp4_buckets = FP4_BUCKETS

__all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']
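
# Minimal usage sketch (illustrative only): the tensor shapes below are
# arbitrary examples, and the block size of 16 values per scaling factor is
# assumed from the usual NVFP4 convention.
if __name__ == "__main__":
    # Two FP4 (e2m1) values are packed into each uint8 element, so a row of
    # 32 FP4 values is stored as 16 packed bytes.
    packed = torch.zeros(4, 16, dtype=float4_e2m1x2)
    # One uint8 scaling factor per block of 16 values -> 2 scales per row here.
    scales = torch.zeros(4, 2, dtype=float4_sf_dtype)
    print(packed.shape, packed.dtype, scales.dtype)
    # pad_up aligns a dimension to a multiple of the smallest GEMM bucket.
    print(pad_up(100, fp4_buckets[0]), fp4_buckets)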