diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 3fa5c877a2..a0a1734fcc 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -49,6 +49,7 @@ from ...utils.node_utils import ( is_any_moe_op, is_any_ssm_op, is_op, + num_users_of_weight_node, shape, subgraph, ) @@ -1237,6 +1238,13 @@ def _shard_parameter_node( rank, world_size = config.rank, config.world_size allreduce_strategy = config.allreduce_strategy.name + num_users = num_users_of_weight_node(node) + if num_users > 1 or num_users == 0: + ad_logger.warning( + f"Expected exactly one user for the weight node {node.name}, but found {num_users}" + ) + return + # Shard weight using the unified function (also updates the parameter) weight_nodes = extract_weight_nodes(node) for weight_node in weight_nodes.weights: