fix: Fix NVLink version decoding. (#3996)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
yuxianq 2025-05-06 13:56:50 +08:00 committed by GitHub
parent 5a4794b387
commit b6cfe08c52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -414,11 +414,11 @@ def nvlink_version(version_enum: int) -> int:
nvl_version_table = {
1: 1,
2: 2,
3: 2,
4: 2,
5: 3,
6: 3,
7: 4,
3: 2, # 2.2
4: 3,
5: 3, # 3.1
6: 4,
7: 5,
}
return nvl_version_table[version_enum]
@ -429,6 +429,7 @@ def nvlink_bandwidth(nvlink_version: int) -> int:
2: 150,
3: 300,
4: 450,
5: 900,
}
return nvl_bw_table[nvlink_version]
@ -483,7 +484,7 @@ def infer_cluster_info() -> ClusterInfo:
nvl_version = nvlink_version(nvl_version_enum)
logger.info(f"NVLink version: {nvl_version}")
nvl_bw = nvlink_bandwidth(nvl_version)
logger.info(f"NVLink bandwidth: {nvl_bw} GB/s")
logger.info(f"NVLink bandwidth (unidirectional): {nvl_bw} GB/s")
intra_node_bw = nvl_bw
if nvl_version >= 4:
intra_node_sharp = True