mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-05 02:31:33 +08:00
[TRTLLM-10785][feat] Fix sharding dashboard errors (#10786)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in:
parent
58311b2345
commit
eb326073d8
@ -49,6 +49,7 @@ from ...utils.node_utils import (
|
||||
is_any_moe_op,
|
||||
is_any_ssm_op,
|
||||
is_op,
|
||||
num_users_of_weight_node,
|
||||
shape,
|
||||
subgraph,
|
||||
)
|
||||
@ -1237,6 +1238,13 @@ def _shard_parameter_node(
|
||||
rank, world_size = config.rank, config.world_size
|
||||
allreduce_strategy = config.allreduce_strategy.name
|
||||
|
||||
num_users = num_users_of_weight_node(node)
|
||||
if num_users > 1 or num_users == 0:
|
||||
ad_logger.warning(
|
||||
f"Expected exactly one user for the weight node {node.name}, but found {num_users}"
|
||||
)
|
||||
return
|
||||
|
||||
# Shard weight using the unified function (also updates the parameter)
|
||||
weight_nodes = extract_weight_nodes(node)
|
||||
for weight_node in weight_nodes.weights:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user