[TRTLLM-6342][bug] Fix shape propagation after TP sharding (#7912)

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Grzegorz Kwasniewski 2025-10-01 17:15:46 +02:00 committed by GitHub
parent ba8abeab10
commit 6fd225833c


@@ -54,6 +54,25 @@ def _load_hook_remove(
        state_dict.pop(key, None)


def _update_view_nodes(node: Node) -> None:
    """
    After column-sharding the weights of a linear node in the attention
    module (Q, K, V), the output Y = X @ W^T has shape
    [batch, seq, num_heads // TP_size, head_dim].
    Some models hardcode the output shape as [batch, seq, num_heads, head_dim]
    instead of the implicit [batch, seq, -1, head_dim].
    Detect such cases and update the shape argument of the view node accordingly.
    """
    view_nodes = [n for n in node.users if is_op(n, torch.ops.aten.view)]
    for view_node in view_nodes:
        view_shape = view_node.args[1]
        if len(view_shape) == 4 and view_shape[2] != -1:
            args = list(view_node.args)
            args[1] = [view_shape[0], view_shape[1], -1, view_shape[3]]
            view_node.args = tuple(args)
            ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}")


def _insert_sharded_matmul(
    gm: GraphModule,
    node: Node,
@@ -157,8 +176,9 @@ def _insert_sharded_matmul(
        world_size=world_size,
    )
    # no comm node needed for single device
    # column shard with no gather: the output is sharded
    if not add_dist:
        _update_view_nodes(node)
        return
    # figure out the right dist op
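
For reference, here is a minimal standalone sketch of the rewrite that `_update_view_nodes` performs, traced with plain `torch.fx` on a toy module rather than on the exported `torch.ops.aten.view` graph the transform actually operates on; `TinyQKVProj` and `relax_hardcoded_views` are hypothetical names used only for illustration.

```python
import torch
import torch.fx as fx


class TinyQKVProj(torch.nn.Module):
    """Toy projection that hardcodes num_heads in its view call."""

    def __init__(self, hidden: int = 64, num_heads: int = 8) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden // num_heads
        self.qkv = torch.nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.qkv(x)
        # Hardcoded head count: this breaks once qkv is column-sharded across TP
        # ranks, because dim 2 would then have to be num_heads // TP_size.
        return y.view([x.shape[0], x.shape[1], self.num_heads, self.head_dim])


def relax_hardcoded_views(gm: fx.GraphModule) -> fx.GraphModule:
    """Rewrite 4-D view shapes [b, s, num_heads, head_dim] to [b, s, -1, head_dim]."""
    for n in gm.graph.nodes:
        if n.op == "call_method" and n.target == "view":
            shape = n.args[1]
            if isinstance(shape, (list, tuple)) and len(shape) == 4 and shape[2] != -1:
                # Replace the hardcoded head count with -1 so the view stays valid
                # regardless of how the preceding linear layer is sharded.
                n.args = (n.args[0], [shape[0], shape[1], -1, shape[3]])
    gm.recompile()
    return gm


gm = relax_hardcoded_views(fx.symbolic_trace(TinyQKVProj()))
x = torch.randn(2, 5, 64)
print(gm(x).shape)  # torch.Size([2, 5, 8, 8]); the -1 keeps the view valid after sharding
```

Using `-1` for the head dimension lets the same view node survive any TP degree, since the per-rank hidden size of the sharded output already determines how many heads remain on each rank.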