mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
fix: Fix NVLink version decoding. (#3996)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
parent
5a4794b387
commit
b6cfe08c52
@ -414,11 +414,11 @@ def nvlink_version(version_enum: int) -> int:
|
||||
nvl_version_table = {
|
||||
1: 1,
|
||||
2: 2,
|
||||
3: 2,
|
||||
4: 2,
|
||||
5: 3,
|
||||
6: 3,
|
||||
7: 4,
|
||||
3: 2, # 2.2
|
||||
4: 3,
|
||||
5: 3, # 3.1
|
||||
6: 4,
|
||||
7: 5,
|
||||
}
|
||||
return nvl_version_table[version_enum]
|
||||
|
||||
@ -429,6 +429,7 @@ def nvlink_bandwidth(nvlink_version: int) -> int:
|
||||
2: 150,
|
||||
3: 300,
|
||||
4: 450,
|
||||
5: 900,
|
||||
}
|
||||
return nvl_bw_table[nvlink_version]
|
||||
|
||||
@ -483,7 +484,7 @@ def infer_cluster_info() -> ClusterInfo:
|
||||
nvl_version = nvlink_version(nvl_version_enum)
|
||||
logger.info(f"NVLink version: {nvl_version}")
|
||||
nvl_bw = nvlink_bandwidth(nvl_version)
|
||||
logger.info(f"NVLink bandwidth: {nvl_bw} GB/s")
|
||||
logger.info(f"NVLink bandwidth (unidirectional): {nvl_bw} GB/s")
|
||||
intra_node_bw = nvl_bw
|
||||
if nvl_version >= 4:
|
||||
intra_node_sharp = True
|
||||
|
||||
Loading…
Reference in New Issue
Block a user