TensorRT-LLMs/tensorrt_llm/auto_parallel/cluster_info.py
yuxianq b6cfe08c52
fix: Fix NVLink version decoding. (#3996)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
2025-05-06 13:56:50 +08:00

536 lines
15 KiB
Python

import copy
import re
from dataclasses import dataclass, field
from typing import Dict, Tuple, Union
import pynvml
import torch
from cuda import cudart
from tensorrt_llm._utils import DictConversion
from tensorrt_llm.logger import logger
from tensorrt_llm.profiler import PyNVMLContext, _device_get_memory_info_fn
@dataclass
class MathThroughput(DictConversion):
int4: int = 0 # Tflops
int8: int = 0 # Tflops
fp8: int = 0 # Tflops
float16: int = 0 # Tflops
bfloat16: int = 0 # Tflops
float32: int = 0 # Tflops
@staticmethod
def to_tflops(
ipc_per_sm: "MathThroughput",
sm_count: int,
clock_mhz: int,
) -> "MathThroughput":
tflops = MathThroughput()
for name in ipc_per_sm.__dataclass_fields__:
setattr(
tflops, name,
getattr(ipc_per_sm, name) * sm_count * clock_mhz // int(1e6))
return tflops
@dataclass
class ClusterInfo(DictConversion):
inter_node_bw_per_device: int = 25 # GBps
intra_node_bw_per_device: int = 0 # GBps
inter_node_latency: int = 10 # us
intra_node_latency: int = 10 # us
intra_node_sharp: bool = False
inter_node_sharp: bool = True
memory_bw: int = 0 # GBps
memory_budget_per_device: int = 0 # GB
math_throughput: MathThroughput = field(default_factory=MathThroughput)
memory_efficiency: float = 1.0
math_efficiency: float = 1.0
communication_efficiency: float = 1.0
_math_throughputs = {
"A100": MathThroughput(
int8=624,
float16=312,
bfloat16=312,
float32=156,
),
}
_bandwidths = {
"PCIe-3": 16,
"PCIe-4": 32,
"PCIe-5": 64,
}
cluster_infos = {
# from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
"A100-SXM-80GB":
ClusterInfo(
intra_node_bw_per_device=300,
memory_bw=2039,
memory_budget_per_device=80,
math_throughput=_math_throughputs["A100"],
),
"A100-SXM-40GB":
ClusterInfo(
intra_node_bw_per_device=300,
memory_bw=1555,
memory_budget_per_device=40,
math_throughput=_math_throughputs["A100"],
),
"A100-PCIe-80GB":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=1935,
memory_budget_per_device=80,
math_throughput=_math_throughputs["A100"],
),
"A100-PCIe-40GB":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=1555,
memory_budget_per_device=40,
math_throughput=_math_throughputs["A100"],
),
# from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
"H100-SXM":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
intra_node_sharp=True,
memory_bw=3350,
memory_budget_per_device=80,
math_throughput=MathThroughput(
int8=1979,
fp8=1979,
float16=989,
bfloat16=989,
float32=495,
),
),
"H100-PCIe":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=_bandwidths["PCIe-5"],
memory_bw=2000,
memory_budget_per_device=80,
math_throughput=MathThroughput(
int8=1513,
fp8=1513,
float16=756,
bfloat16=756,
float32=378,
),
),
"H20":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
memory_bw=4000,
memory_budget_per_device=96,
math_throughput=MathThroughput(
int8=293,
fp8=293,
float16=147,
bfloat16=147,
float32=74,
),
),
# from https://nvdam.widen.net/s/nb5zzzsjdf/hpc-datasheet-sc23-h200-datasheet-3002446
"H200-SXM":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
memory_bw=4800,
memory_budget_per_device=141,
math_throughput=MathThroughput(
int8=3958,
fp8=3958,
float16=1979,
bfloat16=1979,
float32=67,
),
),
"H200-NVL":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
memory_bw=4800,
memory_budget_per_device=141,
math_throughput=MathThroughput(
int8=3341,
fp8=3341,
float16=1671,
bfloat16=1671,
float32=60,
),
),
# from https://images.nvidia.cn/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
"A40":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=696,
memory_budget_per_device=48,
math_throughput=MathThroughput(
int4=600,
int8=300,
float16=150,
bfloat16=150,
float32=75,
),
),
# from https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf
"A30":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=933,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int4=661,
int8=330,
float16=165,
bfloat16=165,
float32=82,
),
),
# from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/datasheet-new/nvidia-a10-datasheet.pdf
"A10":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=600,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int4=500,
int8=250,
float16=125,
bfloat16=125,
float32=62.5,
),
),
"A10G":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=600,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int4=280,
int8=140,
float16=70,
bfloat16=70,
float32=35,
),
),
# from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413
"L40S":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=864,
memory_budget_per_device=48,
math_throughput=MathThroughput(
int4=733,
int8=733,
fp8=733,
float16=362,
bfloat16=362,
float32=183,
),
),
# from https://images.nvidia.cn/content/Solutions/data-center/vgpu-L40-datasheet.pdf
"L40":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=864,
memory_budget_per_device=48,
math_throughput=MathThroughput(
int4=724,
int8=362,
fp8=362,
float16=181,
bfloat16=181,
float32=90,
),
),
"L20":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=864,
memory_budget_per_device=48,
math_throughput=MathThroughput(
int8=238,
fp8=238,
float16=119,
bfloat16=119,
float32=60,
),
),
# from https://nvdam.widen.net/s/rvq98gbwsw/l4-datasheet-2595652
"L4":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=300,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int8=242,
fp8=242,
float16=120,
bfloat16=120,
float32=60,
),
),
"L2":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=300,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int8=193,
fp8=193,
float16=97,
bfloat16=97,
float32=48,
),
),
}
def infer_cluster_key() -> str:
def match(product, name):
# Use A100 as example, the regex pattern matches for:
# - NVIDIA A100 80GB
# - NVIDIA A100-PCIE
# - NVIDIA A100
# And does not match A1000 etc.
return re.match(f".*{product}([ -]|$).*", name) is not None
def is_sxm():
return "SXM" in device_name
def is_80gb():
return "80GB" in device_name
def is_32gb():
return "32GB" in device_name
device_name = torch.cuda.get_device_name(torch.cuda.current_device())
if match("A100", device_name):
if is_sxm():
if is_80gb():
return "A100-SXM-80GB"
else:
return "A100-SXM-40GB"
else:
if is_80gb():
return "A100-PCIe-80GB"
else:
return "A100-PCIe-40GB"
elif match("A10G", device_name):
return "A10G"
elif match("A10", device_name):
return "A10"
elif match("A30", device_name):
return "A30"
elif match("A40", device_name):
return "A40"
elif match("H100", device_name):
if is_sxm():
return "H100-SXM"
else:
return "H100-PCIe"
elif match("H200", device_name):
if is_sxm():
return "H200-SXM"
else:
return "H200-NVL"
elif match("L40S", device_name):
return "L40S"
elif match("L40", device_name):
return "L40"
elif match("L4", device_name):
return "L4"
return None
def ipc_per_sm(compute_cap: Tuple[int, int]) -> MathThroughput:
ipc_table = {
(9, 0):
MathThroughput(
int8=16384,
fp8=16384,
float16=8192,
bfloat16=8192,
float32=4096,
),
(8, 0):
MathThroughput(
int4=8192,
int8=4096,
float16=2048,
bfloat16=2048,
float32=1024,
),
(8, 6):
MathThroughput(
int4=4096,
int8=2048,
float16=1024,
bfloat16=1024,
float32=512,
),
(8, 9):
MathThroughput(
int4=2048,
int8=1024,
fp8=1024,
float16=512,
bfloat16=512,
float32=256,
),
(7, 0):
MathThroughput(
float16=1024,
float32=128,
),
(7, 5):
MathThroughput(
int4=4096,
int8=2048,
float16=1024,
float32=128,
),
}
return ipc_table.get(compute_cap, MathThroughput())
def nvlink_version(version_enum: int) -> int:
nvl_version_table = {
1: 1,
2: 2,
3: 2, # 2.2
4: 3,
5: 3, # 3.1
6: 4,
7: 5,
}
return nvl_version_table[version_enum]
def nvlink_bandwidth(nvlink_version: int) -> int:
nvl_bw_table = {
1: 80,
2: 150,
3: 300,
4: 450,
5: 900,
}
return nvl_bw_table[nvlink_version]
def infer_cluster_info() -> ClusterInfo:
device = torch.cuda.current_device()
index = device.index if isinstance(device, torch.device) else device
with PyNVMLContext():
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
logger.info(f"Compute capability: {compute_cap}")
err, properties = cudart.cudaGetDeviceProperties(index)
sm_count = properties.multiProcessorCount
logger.info(f"SM count: {sm_count}")
sm_clock = pynvml.nvmlDeviceGetMaxClockInfo(
handle,
pynvml.NVML_CLOCK_SM,
)
logger.info(f"SM clock: {sm_clock} MHz")
math_throughput = MathThroughput.to_tflops(
ipc_per_sm(compute_cap),
sm_count,
sm_clock,
)
for name in math_throughput.__dataclass_fields__:
tflops = getattr(math_throughput, name)
logger.info(f"{name} TFLOPS: {tflops}")
mem_info = _device_get_memory_info_fn(handle)
memory_budget = mem_info.total // (1024**3)
logger.info(f"Total Memory: {memory_budget} GiB")
mem_clock = pynvml.nvmlDeviceGetMaxClockInfo(
handle,
pynvml.NVML_CLOCK_MEM,
)
logger.info(f"Memory clock: {mem_clock} MHz")
mem_bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
logger.info(f"Memory bus width: {mem_bus_width}")
memory_bw = mem_bus_width * mem_clock * 2 // int(8e3)
logger.info(f"Memory bandwidth: {memory_bw} GB/s")
try:
is_nvl_active = bool(pynvml.nvmlDeviceGetNvLinkState(handle, 0))
logger.info(f"NVLink is active: {is_nvl_active}")
except pynvml.NVMLError:
is_nvl_active = False
intra_node_sharp = False
if is_nvl_active:
nvl_version_enum = pynvml.nvmlDeviceGetNvLinkVersion(handle, 0)
nvl_version = nvlink_version(nvl_version_enum)
logger.info(f"NVLink version: {nvl_version}")
nvl_bw = nvlink_bandwidth(nvl_version)
logger.info(f"NVLink bandwidth (unidirectional): {nvl_bw} GB/s")
intra_node_bw = nvl_bw
if nvl_version >= 4:
intra_node_sharp = True
else:
pcie_speed = pynvml.nvmlDeviceGetPcieSpeed(handle)
logger.info(f"PCIe speed: {pcie_speed} Mbps")
pcie_link_width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(handle)
logger.info(f"PCIe link width: {pcie_link_width}")
pcie_bw = pcie_speed * pcie_link_width // int(8e3)
logger.info(f"PCIe bandwidth: {pcie_bw} GB/s")
intra_node_bw = pcie_bw
cluster_info = ClusterInfo(
math_throughput=math_throughput,
memory_bw=memory_bw,
memory_budget_per_device=memory_budget,
intra_node_bw_per_device=intra_node_bw,
intra_node_sharp=intra_node_sharp,
)
return cluster_info
def infer_cluster_config() -> Dict[str, Union[str, ClusterInfo]]:
device_name = torch.cuda.get_device_name(torch.cuda.current_device())
cluster_key = infer_cluster_key()
if cluster_key is not None:
return dict(cluster_key=cluster_key)
else:
try:
cluster_info = infer_cluster_info()
except pynvml.NVMLError:
fallback_cluster_key = "L40"
cluster_info = copy.copy(cluster_infos[fallback_cluster_key])
memory_budget = torch.cuda.mem_get_info()[1] // (1024**3)
cluster_info.memory_budget_per_device = memory_budget
logger.warning(
f"Failed to infer cluster info for {device_name}, "
f"treat it as a {fallback_cluster_key} node with {memory_budget} GB memory. "
"This setting makes no effect if you do not use auto parallel.")
return dict(
cluster_key=device_name.replace(" ", "-"),
cluster_info=cluster_info,
)
if __name__ == "__main__":
logger.set_level("info")
infer_cluster_info()