Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-6342][bug] Fix shape propagation after TP sharding (#7912)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in: parent ba8abeab10, commit 6fd225833c
@@ -54,6 +54,25 @@ def _load_hook_remove(
    state_dict.pop(key, None)


def _update_view_nodes(node: Node) -> None:
    """
    After sharding the weights of the linear node with a column split
    in the attention module (Q, K, V),
    the output Y = X @ W^T is [batch, seq, num_heads // TP_size, head_dim].
    Some models hardcode the output shape as [batch, seq, num_heads, head_dim]
    instead of the implicit [batch, seq, -1, head_dim].
    Detect such cases and update the shape of the view node accordingly.
    """
    view_nodes = [n for n in node.users if is_op(n, torch.ops.aten.view)]
    for view_node in view_nodes:
        view_shape = view_node.args[1]
        if len(view_shape) == 4 and view_shape[2] != -1:
            args = list(view_node.args)
            args[1] = [view_shape[0], view_shape[1], -1, view_shape[3]]
            view_node.args = tuple(args)
            ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}")


def _insert_sharded_matmul(
    gm: GraphModule,
    node: Node,
@@ -157,8 +176,9 @@ def _insert_sharded_matmul(
        world_size=world_size,
    )

    # no comm node needed for single device
    # column shard with no gather: the output is sharded
    if not add_dist:
        _update_view_nodes(node)
        return

    # figure out the right dist op
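Design note (a reading of the change, not stated in the commit message): when add_dist is set, a collective follows the column-sharded matmul and restores the full output width, so a hardcoded [batch, seq, num_heads, head_dim] view still matches. Only on the no-gather path does each rank keep a num_heads // TP_size slice, e.g. with num_heads = 32 and TP_size = 4 the local output reshapes to [batch, seq, 8, head_dim], which is why _update_view_nodes is applied exactly on this early-return branch.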