Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-10318][feat] Fixing Nemotron sharding: support for sharding buffers (#10319)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Lucas <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
Co-authored-by: Lucas <11156568+lucaslie@users.noreply.github.com>
Parent: cef67b4f8d
Commit: 7bf4dd9f63
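The change replaces name-based parameter lookups (extract_param_names_from_node) with structured weight records: extract_weight_nodes(node) returns a WeightNodes container whose weights and biases entries each carry the fx node, the resolved tensor, its dotted key, and the owning submodule, with tensors resolved through get_param_or_buffer so registered buffers are handled like parameters. A minimal consumer sketch, assuming only the interfaces shown in the hunks below (the import path is inferred from the test imports in this diff and may differ):

from torch.fx import Node

from tensorrt_llm._torch.auto_deploy.utils.node_utils import extract_weight_nodes


def local_shard_sizes(node: Node, world_size: int) -> dict:
    """Illustration only: report the per-rank dim-0 size of every weight/bias tied to `node`.

    Buffers are covered as well as parameters because WeightNode.tensor is resolved
    via get_param_or_buffer (see the node_utils hunk further down).
    """
    weight_nodes = extract_weight_nodes(node)
    return {
        wn.node_key: wn.tensor.shape[0] // world_size
        for wn in (*weight_nodes.weights, *weight_nodes.biases)
    }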
@@ -13,7 +13,7 @@ from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op
from ...utils.node_utils import extract_weight_name, is_linear_op, is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry

@@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
y2 = y[:, out1:out1+out2]
"""
# some info we need
keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
keys_unfused = [extract_weight_name(n) for n in linear_nodes]
params_unfused = [gm.get_parameter(k) for k in keys_unfused]
sizes_unfused = [p.size(0) for p in params_unfused]
key_fused = f"fused_weight_{idx}"

@@ -128,7 +128,7 @@ class QuantizationFusionMixin(ABC):
def _insert_fused_quant_gemm(
self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]
):
keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
keys_unfused = [extract_weight_name(n) for n in linear_nodes]
params_unfused = [gm.get_parameter(k) for k in keys_unfused]
sizes_unfused = [p.size(0) for p in params_unfused]
key_fused = f"fused_weight_{idx}"
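The fused-GEMM docstring above relies on callers slicing the fused output back into per-projection outputs (y2 = y[:, out1:out1+out2]). A standalone illustration of that recovery with made-up sizes, using torch.split instead of manual slicing:

import torch

sizes_unfused = [128, 256]  # dim-0 sizes of the two original weights (illustrative values)
w_fused = torch.randn(sum(sizes_unfused), 64)
x = torch.randn(8, 64)

y = x @ w_fused.T  # fused output of shape (8, 384)
y1, y2 = torch.split(y, sizes_unfused, dim=-1)
assert y1.shape == (8, 128) and y2.shape == (8, 256)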
@@ -17,7 +17,7 @@ from ...custom_ops.quant import (
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import (
extract_param_names_from_node,
extract_weight_nodes,
get_quantization_params_from_linear_node,
is_bmm_op,
is_linear_op,

@@ -139,13 +139,13 @@ class Quantization(BaseTransform):
The state_dict is also updated to contain the sharded weights.
"""
param_name, _ = extract_param_names_from_node(node)
original_weight = gm.get_parameter(param_name)
new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False)
modname, _, attrname = param_name.rpartition(".")
weight_nodes = extract_weight_nodes(node)
assert len(weight_nodes.weights) == 1, "Expected exactly one weight node"
lin_weight = weight_nodes.weights[0]
new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
modname, _, attrname = lin_weight.node_key.rpartition(".")

submod = gm.get_submodule(modname)
setattr(submod, attrname, new_param)
setattr(lin_weight.submod, attrname, new_param)

# check modelopt quantizers from graph
if is_quantized_graph:

@@ -171,10 +171,12 @@ class Quantization(BaseTransform):
)
# Note: canonicalize_graph() will remove input/weight/output quantizer

for scale_name, scale in self.default_scales(original_weight.shape).items():
submod.register_buffer(scale_name, scale)
for scale_name, scale in self.default_scales(lin_weight.tensor.shape).items():
lin_weight.submod.register_buffer(scale_name, scale)

gm._register_load_state_dict_pre_hook(partial(self.load_hook, weight_name=param_name))
gm._register_load_state_dict_pre_hook(
partial(self.load_hook, weight_name=lin_weight.node_key)
)

with gm.graph.inserting_before(node):
scales = {}
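Note that the quantization transform registers its default scales with register_buffer, so they appear under named_buffers() rather than named_parameters(); the sharder's weight lookup now also covers buffers (see get_param_or_buffer further down), which is what makes such tensors visible to it. A plain-PyTorch illustration, not repo code:

import torch
from torch import nn

lin = nn.Linear(16, 32, bias=False)
lin.register_buffer("weight_scale", torch.ones(32))

assert "weight_scale" in dict(lin.named_buffers())
assert "weight_scale" not in dict(lin.named_parameters())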
@@ -39,8 +39,8 @@ from ...utils.node_utils import (
LayerSubgraph,
LayerType,
bfs,
extract_param_names_from_node,
extract_weight_node,
extract_weight_name,
extract_weight_nodes,
filtered_nodes,
get_all_layer_subgraphs,
get_layer_after_linear_node,

@@ -49,7 +49,6 @@ from ...utils.node_utils import (
is_any_moe_op,
is_any_ssm_op,
is_op,
num_users_of_weight_node,
shape,
subgraph,
)

@@ -268,7 +267,7 @@ class WeightShardingInfo(ShardingTransformInfo):
min_local_shape: int = 1
layer_type: LayerType = LayerType.MLP
# used for TP sharding of fused weights
fused_weight_dims: Optional[list] = None
fused_weight_dims: Optional[tuple] = None

def quantization_cb(
self,

@@ -437,7 +436,7 @@ class FP8WeightShardingInfo(QuantizationShardingMixin, WeightShardingInfo):

def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, world_size):
assert weight_scale.dim() == 1
# assert weight_scale.dim() == 1
weight_shape_original = list(sharded_uint8_weight_shape)
weight_shape_original[dim] = weight_shape_original[dim] * world_size
weight_shape_original[-1] *= 2

@@ -895,13 +894,10 @@ def _load_hook(
# This is quite a hacky solution. A better solution would be to store extra_state in
# the state_dict to identify whether the state_dict is sharded or not.
key = prefix + param_key
ad_logger.debug(f"Sharder LOAD hook is called for '{key}'")
if key not in state_dict:
return
p_to_load = state_dict[key]

p_to_load = p_to_load if param_shape == p_to_load.shape else f_split(p_to_load)

state_dict[key] = p_to_load

@@ -1124,6 +1120,7 @@ def init_process_grid_from_config(
ShardingDim.EP: {"p": ep_rank, "w": ep_size},
ShardingDim.TP: {"p": tp_rank, "w": tp_size},
}
ad_logger.info(f"EP + TP sharding process grid: {process_grid}")
config.process_grid = process_grid
return process_grid

@@ -1187,10 +1184,6 @@ def shard_weight_tensor(
fused_dims: list = fused_weight_dims,
d: int = dim,
) -> torch.Tensor:
# dim_d = t.shape[d]
# num_parts = 1
# part_size = dim_d // num_parts
# fused_dims = [part_size] * num_parts
return torch.cat(
[split_tensor(w) for w in torch.split(t, fused_dims, dim=d)],
dim=d,
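The split above shards each fused segment separately and re-concatenates the per-rank pieces, so every rank keeps a contiguous slice of each projection rather than a slice of the raw fused tensor. A standalone worked example with made-up dimensions:

import torch

fused_dims = [4, 6, 2]  # e.g. three projections fused along dim 0 (illustrative values)
t = torch.arange(sum(fused_dims) * 3, dtype=torch.float32).reshape(sum(fused_dims), 3)
rank, world_size = 0, 2


def split_tensor(w: torch.Tensor) -> torch.Tensor:
    return torch.chunk(w, world_size, dim=0)[rank]


local = torch.cat([split_tensor(w) for w in torch.split(t, fused_dims, dim=0)], dim=0)
assert local.shape[0] == sum(fused_dims) // world_size  # each rank keeps 6 of the 12 rows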
@@ -1229,7 +1222,7 @@ def _shard_parameter_node(
config: ShardingTransformConfig,
add_dist: bool = False,
min_local_shape: int = 1,
fused_weight_dims: Optional[list] = None,
fused_weight_dims: Optional[tuple] = None,
quantization_cb: Optional[
Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None]
] = None,

@@ -1243,67 +1236,58 @@ def _shard_parameter_node(

rank, world_size = config.rank, config.world_size
allreduce_strategy = config.allreduce_strategy.name
num_users = num_users_of_weight_node(node)
if num_users > 1 or num_users == 0:
ad_logger.warning(
f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping."
)
return
# get weight and bias key
weight_key, bias_key = extract_param_names_from_node(node)

modname = weight_key.rpartition(".")[0]
submod = gm.get_submodule(modname)

# Shard weight using the unified function (also updates the parameter)
original_weight = gm.get_parameter(weight_key)
_, weight_new_shape = shard_weight_tensor(
gm=gm,
weight_tensor=original_weight,
param_key=weight_key,
dim=dim,
rank=rank,
world_size=world_size,
min_local_shape=min_local_shape,
fused_weight_dims=fused_weight_dims,
)

if bias_key is not None and dim == 0:
# update bias for dim 0 --> we can handle it like the weight
original_bias = gm.get_parameter(bias_key)
shard_weight_tensor(
weight_nodes = extract_weight_nodes(node)
for weight_node in weight_nodes.weights:
_, weight_new_shape = shard_weight_tensor(
gm=gm,
weight_tensor=original_bias,
param_key=bias_key,
weight_tensor=weight_node.tensor,
param_key=weight_node.node_key,
dim=dim,
rank=rank,
world_size=world_size,
min_local_shape=min_local_shape,
fused_weight_dims=fused_weight_dims,
)
elif bias_key is not None and rank != world_size - 1:
# update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
# double counting it. For all other we will delete the bias.
args = list(node.args)
node_bias = args[2]
args[2] = None
node.args = tuple(args)
gm.graph.erase_node(node_bias)
bias_param_name = bias_key.rpartition(".")[-1]
setattr(submod, bias_param_name, None)
gm._register_load_state_dict_pre_hook(partial(_load_hook_remove, param_key=bias_key))
if quantization_cb is not None:
quantization_cb(
gm=gm,
submod=weight_node.submod,
node=node,
weight_key=weight_node.node_key,
weight_new_shape=weight_new_shape,
dim=dim,
rank=rank,
world_size=world_size,
)

if quantization_cb is not None:
quantization_cb(
gm=gm,
submod=submod,
node=node,
weight_key=weight_key,
weight_new_shape=weight_new_shape,
dim=dim,
rank=rank,
world_size=world_size,
)
for bias_node in weight_nodes.biases:
if dim == 0:
# update bias for dim 0 --> we can handle it like the weight
shard_weight_tensor(
gm=gm,
weight_tensor=bias_node.tensor,
param_key=bias_node.node_key,
dim=dim,
rank=rank,
world_size=world_size,
min_local_shape=min_local_shape,
fused_weight_dims=fused_weight_dims,
)
elif bias_node is not None and rank != world_size - 1:
# update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
# double counting it. For all other we will delete the bias.
args = list(node.args)
node_bias = args[2]
args[2] = None
node.args = tuple(args)
gm.graph.erase_node(node_bias)
bias_param_name = bias_node.node_key.rpartition(".")[-1]
setattr(bias_node.submod, bias_param_name, None)
gm._register_load_state_dict_pre_hook(
partial(_load_hook_remove, param_key=bias_node.node_key)
)

# # # column shard with no gather: the output is sharded
if not add_dist:

@@ -1633,7 +1617,7 @@ def _process_ssm_sharding(
config=config,
dist_op=None,
min_local_shape=1,
fused_weight_dims=fused_weight_dims["in_proj"],
fused_weight_dims=tuple(fused_weight_dims["in_proj"]),
layer_type=LayerType.SSM,
)
):

@@ -1702,7 +1686,7 @@ def _process_ssm_sharding(
fused_dims = None
for k, v in fused_weight_dims.items():
if k in weight_key:
fused_dims = v
fused_dims = tuple(v)
break

# Shard the weight tensor (also updates the parameter in the module)

@@ -1887,7 +1871,7 @@ def _determine_fused_weight_dims(
ad_logger.warning(
f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
)
return
return None
chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
if len(chunk_nodes) > 0:
assert len(linear_nodes) == 1

@@ -1896,6 +1880,8 @@ def _determine_fused_weight_dims(
num_chunks = chunk_nodes[0].args[1]
weight_dim = shape(linear_node)[2]
fused_weight_dims = [weight_dim // num_chunks] * num_chunks
if fused_weight_dims is not None:
fused_weight_dims = tuple(fused_weight_dims)
return fused_weight_dims
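The chunk-based inference above derives equal fused segments from an aten.chunk user of the linear node and now normalizes the result to a tuple. With illustrative numbers:

num_chunks = 2  # taken from chunk_nodes[0].args[1] in the hunk above
weight_dim = 256  # taken from shape(linear_node)[2]
fused_weight_dims = tuple([weight_dim // num_chunks] * num_chunks)
assert fused_weight_dims == (128, 128)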
@@ -2046,9 +2032,9 @@ def detect_sharding_from_config(

for lin_node in linear_nodes:
# use node's weight name to get the module name
module_name = extract_weight_node(lin_node).target
weight_name = extract_weight_name(lin_node)

if any(attn_name in module_name for attn_name in attn_names):
if any(attn_name in weight_name for attn_name in attn_names):
# find the next attention node and infer the head_dim
next_attention_node, _ = bfs(
lin_node, is_any_attention_op, attr_next="users", include_root=False

@@ -2072,7 +2058,7 @@ def detect_sharding_from_config(
# Then we escape dots, and finally we replace @ with .*
pattern_string = pattern_string.replace("*", "@")
pattern_regex = re.escape(pattern_string).replace("@", ".*")
if re.match(pattern_regex, module_name):
if re.match(pattern_regex, weight_name):
# we have a match. Get the config for this layer
config = tp_plan[key]

@@ -2111,7 +2097,7 @@ def detect_sharding_from_config(
elif "local" in config:
# Check if this applies to shared experts in EP parallelism.
# If yes, apply the TP col-row shard.
if "shared" in module_name:
if "shared" in weight_name:
col_row_action = config.replace("local_", "")
if col_row_action == "colwise":
transform_container.add(

@@ -2235,7 +2221,6 @@ def detect_column_row_shard(
min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism
splitting, e.g., the individual heads into smaller shards.
"""
# test_moe_variants()
ad_logger.debug("Before sharding graph: " + str(gm))
config = transform_container.config
world_size = config.world_size

@@ -2340,7 +2325,7 @@ def detect_column_row_shard(
# simple shard remaining linear nodes
if config.shard_all_unprocessed:
num_simple_shards += _process_simple_shard(unprocessed_linear_nodes, transform_container)
num_column_row_shards += num_ssm_shards
num_column_row_shards += num_ssm_shards + num_mla_shards
num_shards = num_simple_shards + num_column_row_shards
ad_logger.info(
f"Heuristics found {num_shards} TP shards. Simple: {num_simple_shards}, "

@@ -354,7 +354,7 @@ def get_input_embeddings(model: nn.Module) -> torch.Tensor:
op="call_function", target=torch.ops.aten.embedding.default
)
for node in found_nodes:
embedding_weights.append(get_weight_tensor(gm, node))
embedding_weights.append(get_weight_tensor(node))

if hasattr(model, "get_input_embeddings"):
embedding_weights.append(model.get_input_embeddings())

@@ -400,7 +400,7 @@ def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Nod
def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
gm, output_node = get_output_node(model)
lm_head_node = get_lm_head_node(gm, output_node)
return get_weight_tensor(gm, lm_head_node)
return get_weight_tensor(lm_head_node)


def get_attr_by_name(obj, name):

@@ -8,6 +8,7 @@ from typing import Callable, Iterable, List, Optional, Tuple, Union

import torch
from pydantic import BaseModel, ConfigDict
from torch import nn
from torch._ops import OpOverload, OpOverloadPacket
from torch.fx import GraphModule, Node

@@ -51,6 +52,19 @@ class LayerSubgraph(BaseModel):
min_local_shape: int = 1


class WeightNode(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
node: Node
tensor: torch.Tensor
node_key: str
submod: nn.Module


class WeightNodes(BaseModel):
weights: list[WeightNode]
biases: list[WeightNode]


@dataclass
class modelopt_quant_params:
input_node: torch.fx.node.Node = None
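WeightNode stores torch.fx and torch objects (Node, Tensor, nn.Module) in a pydantic model, which requires arbitrary_types_allowed since pydantic cannot generate schemas for such types. A standalone sketch of the same pattern:

import torch
from pydantic import BaseModel, ConfigDict


class TensorRecord(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)  # without this, class creation fails
    name: str
    tensor: torch.Tensor


record = TensorRecord(name="weight", tensor=torch.zeros(4))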
@@ -129,8 +143,26 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
return input_params, weight_params, output_params


def extract_weight_node(node: Node) -> int:
"""Extracts the weight node from the given parametrized node"""
def extract_weight_name(node: Node) -> str:
weight_nodes = extract_weight_nodes(node)
return weight_nodes.weights[0].node_key


def get_param_or_buffer(tensor_name: str, gm: GraphModule) -> torch.Tensor:
if tensor_name in dict(gm.named_parameters()):
return gm.get_parameter(tensor_name)
elif tensor_name in dict(gm.named_buffers()):
return gm.get_buffer(tensor_name)
else:
raise KeyError(f"Tensor {tensor_name} not found in the graph")


def extract_weight_nodes(node: Node) -> WeightNodes:
"""Extracts the list of weight node and optional bias node from the given parametrized node"""
gm = node.graph.owning_module
param_names = {name for name, _ in gm.named_parameters()}.union(
{name for name, _ in gm.named_buffers()}
)

def find_get_attr_node(weight_node: Node) -> Node:
"""Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""

@@ -141,7 +173,12 @@ def extract_weight_node(node: Node) -> int:
torch.ops.aten.view.default,
}

if weight_node.op == "get_attr":
if (
weight_node.op == "get_attr"
and weight_node.target in param_names
and has_shape(weight_node)
and len(shape(weight_node)) > 0
):
return weight_node

# If node is not in the list of allowable ops then return None

@@ -155,57 +192,56 @@ def extract_weight_node(node: Node) -> int:
return None

if is_op(node, torch.ops.aten.bmm):
weight_node = node.args[1]
# no bias for bmm
weight_node = find_get_attr_node(node.args[1])
return WeightNodes(
weights=[
WeightNode(
node=node.args[1],
node_key=weight_node.target,
tensor=get_param_or_buffer(weight_node.target, gm),
submod=gm.get_submodule(weight_node.target.rpartition(".")[0]),
)
],
biases=[],
)
# for other parametrized nodes, we need to find the weight node
else:
weight_nodes = [
n for n in node.args if isinstance(n, Node) and find_get_attr_node(n) is not None
all_weight_nodes = [
attr_node
for n in node.all_input_nodes
if (attr_node := find_get_attr_node(n)) is not None
]
# can be two weights (if bias weight is present)
weight_node = None
if weight_nodes:
weight_node = weight_nodes[0]
# for modelopt quantized graph, there will be a quantize_op
_, weight_params, _ = get_quantization_params_from_linear_node(node)
weight_node = weight_params.input_node if weight_params else weight_node
assert weight_node is not None, "Expected at least one weight node in the parametrized node"
return find_get_attr_node(weight_node)
# separate weight nodes and bias nodes
bias_nodes = [n for n in all_weight_nodes if n.target.endswith("bias")]
weight_nodes = [n for n in all_weight_nodes if n not in bias_nodes]
weight_nodes = [
WeightNode(
node=n,
node_key=n.target,
submod=gm.get_submodule(n.target.rpartition(".")[0]),
tensor=get_param_or_buffer(n.target, gm),
)
for n in weight_nodes
]
bias_nodes = [
WeightNode(
node=n,
node_key=n.target,
submod=gm.get_submodule(n.target.rpartition(".")[0]),
tensor=get_param_or_buffer(n.target, gm),
)
for n in bias_nodes
]
return WeightNodes(weights=weight_nodes, biases=bias_nodes)


def num_users_of_weight_node(node: Node) -> int:
"""Returns the number of users of the weight node of the given parametrized node."""
weight_node = extract_weight_node(node)
weight_node = extract_weight_nodes(node).weights[0].node
return len(weight_node.users) if weight_node is not None else 0


def extract_param_names_from_node(node: Node) -> Tuple[str, Optional[str]]:
"""Extracts the name of the parameter associated with the given parametrized node.

Args:
node: node with weight parameters in the graph.
"""
weight_node = extract_weight_node(node)

assert weight_node, "Cannot identify weight parameter of linear node."

# Map arg to named parameter
weight_name = weight_node.target

# check for bias
if is_op(node, torch.ops.aten.bmm):
bias_node = node.args[2] if len(node.args) > 2 else None
else:
weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"]
if len(weight_nodes) > 1:
bias_node = weight_nodes[1]
else:
bias_node = None
assert bias_node is None or bias_node.op == "get_attr"
bias_name = bias_node.target if bias_node is not None else None

return weight_name, bias_name


def get_op_overload_packet(node: Union[OpOverloadPacket, OpOverload]) -> OpOverloadPacket:
"""Get the overload packet from the op overload."""
if isinstance(node, OpOverloadPacket):
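get_param_or_buffer (defined in the hunk above) resolves a dotted tensor name against named_parameters() first and falls back to named_buffers(), which is what lets extract_weight_nodes return shardable records for buffers. A standalone illustration of that lookup order on a plain module (GraphModule inherits these nn.Module accessors):

import torch
from torch import nn

mod = nn.Module()
mod.lin = nn.Linear(4, 4)
mod.register_buffer("scale", torch.ones(4))

assert isinstance(mod.get_parameter("lin.weight"), nn.Parameter)
assert "scale" in dict(mod.named_buffers())
# get_param_or_buffer("lin.weight", mod) would go through get_parameter(...),
# while get_param_or_buffer("scale", mod) would fall back to get_buffer(...).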
@@ -335,6 +371,15 @@ def is_any_ssm_op(node: Node) -> bool:
)


def is_any_conv_op(node: Node) -> bool:
return is_op(
node,
ops=[
torch.ops.auto_deploy.torch_causal_conv1d,
],
)


def is_any_attention_op(node: Node) -> bool:
return is_op(
node,

@@ -499,8 +544,10 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
# closing is the last linear node in the layer
layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices)
if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0:
unprocessed_linear_nodes -= set(layer_subgraph.opening_nodes) | set(
[layer_subgraph.terminating_node]
unprocessed_linear_nodes -= (
set(layer_subgraph.opening_nodes)
| set([layer_subgraph.terminating_node])
| set(layer_subgraph.subgraph_nodes)
)
layer_subgraphs.append(layer_subgraph)
last_lin_index = terminating_indices[-1] + 1

@@ -746,7 +793,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in
"""Get the shape of the weight node."""
if not is_any_lin_op(node):
return None
s = list(shape(extract_weight_node(node)))
s = list(shape(extract_weight_nodes(node).weights[0].node))
if len(s) == 0:
return None
if is_fp4_op(node):

@@ -936,10 +983,10 @@ def shape(node: Node) -> Tuple[int, ...]:
return node.meta["val"].shape


def get_weight_tensor(gm: GraphModule, node: Node) -> "torch.Tensor":
def get_weight_tensor(node: Node) -> torch.Tensor:
"""Extract the weight tensor from a node within a GraphModule."""
weight_name = extract_param_names_from_node(node)[0]
return gm.get_parameter(weight_name)
weight_nodes = extract_weight_nodes(node)
return weight_nodes.weights[0].tensor


def draw_graph(gm: GraphModule, filename: str):

@@ -8,7 +8,7 @@ from torch.fx import GraphModule, Node
from ..custom_ops.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX
from .logger import ad_logger
from .node_utils import (
extract_param_names_from_node,
extract_weight_name,
get_quantization_params_from_linear_node,
is_bmm_op,
is_linear_op,

@@ -117,8 +117,8 @@ def should_skip_quantization(
else:
if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)):
return True
param_name, _ = extract_param_names_from_node(node_or_name)
modname, _, _ = param_name.rpartition(".")
weight_name = extract_weight_name(node_or_name)
modname = weight_name.rpartition(".")[0]

return any(fnmatch(modname, pattern) for pattern in excluded_patterns)

@@ -1,6 +1,7 @@
"""Tests for basic graph sharding."""

from functools import partial
from types import SimpleNamespace
from typing import Type

import pytest

@@ -13,6 +14,7 @@ from _model_test_utils import FakeFP8Linear

import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_nemotron_h import NemotronHMamba2Mixer
from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
FP8WeightShardingInfo,
LayerType,

@@ -35,6 +37,14 @@ base_model_tp_plan = {
"linear1": "colwise",
"linear2": "rowwise",
"linear": "gather",
# Mamba2 specific projections
"in_proj": "mamba",
"out_proj": "rowwise",
# MLA specific projections
"q_a_proj": "gather",
"q_b_proj": "colwise",
"kv_a_proj_with_mqa": "gather",
"kv_b_proj": "colwise",
# "input_layernorm.weight": "sequence_parallel",
# "post_attention_layernorm.weight": "sequence_parallel",
# "norm.weight": "sequence_parallel",

@@ -50,7 +60,6 @@ base_model_tp_plan = {
}

predefined_config = {
"head_dim": 8,
"tp_plan": base_model_tp_plan,
}
@@ -125,6 +134,75 @@ class FP8MLP(nn.Module):
return self.linear2(y)


class MLA_Block(nn.Module):
"""Multi-Latent Attention block - simplified standalone version.

Based on DeepSeek MLA architecture with KV compression.
This is a minimal, self-contained implementation for testing sharding patterns.
"""

def __init__(
self,
hidden_size: int,
num_heads: int,
q_lora_rank: int,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
bias: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.kv_lora_rank = kv_lora_rank
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim

# KV compression path (not sharded - gather)
self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=bias)

# KV decompression (sharded column-wise)
self.kv_b_proj = nn.Linear(
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False
)

# Query path (sharded column-wise)
self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=bias)
self.q_b_proj = nn.Linear(q_lora_rank, num_heads * self.qk_head_dim, bias=bias)
self.q_a_layernorm = nn.LayerNorm(q_lora_rank)
# Output projection (sharded row-wise)
self.o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=bias)

@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s, _ = x.shape

# Compress KV to latent
compressed_kv_rope = self.kv_a_proj_with_mqa(x)  # (b, s, kv_lora_rank + rope_dim)
compressed_kv = compressed_kv_rope[:, :, : self.kv_lora_rank]  # (b, s, kv_lora_rank)

# Decompress to full K and V
kv = self.kv_b_proj(compressed_kv)  # (b, s, num_heads * (qk_nope + v))
k_nope_v = kv.view(b, s, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope = k_nope_v[:, :, :, : self.qk_nope_head_dim]
v = k_nope_v[:, :, :, self.qk_nope_head_dim :]

# Query projection
# q = q_b_proj @ (layernorm(q_a_proj @ x))
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))  # (b, s, num_heads * qk_head_dim)
q = q.view(b, s, self.num_heads, self.qk_head_dim)
q_nope = q[:, :, :, : self.qk_nope_head_dim]

attn_out = torch.ops.auto_deploy.torch_attention(q_nope, k_nope, v, is_causal=True)
attn_out = attn_out.contiguous().view(b, s, -1)
# Output projection
output = self.o_proj(attn_out)
return output


def _run_sharding_execution_job(
model_cls: nn.Module,
dist_op_expected: str,

@@ -137,6 +215,7 @@ def _run_sharding_execution_job(
batch_size = 4
sequence_len = 8
num_features = 32
skip_output_assert = False

# GQA specific parameters
num_heads = 4

@@ -150,6 +229,54 @@ def _run_sharding_execution_job(
).to(device="cuda", dtype=torch.float16)
elif model_cls == FP8MLP:
model = model_cls(num_features, num_features, bias=bias).to("cuda")
elif model_cls == NemotronHMamba2Mixer:
# Create config for Mamba2 based on Nemotron models
# Scaled down from typical values: hidden_size=5120, ssm_state_size=128
mamba_config = SimpleNamespace(
hidden_size=num_features,
ssm_state_size=16,  # Scaled from 128
mamba_num_heads=num_heads,
mamba_head_dim=num_features // num_heads,  # 8
n_groups=num_heads,  # Typical value
chunk_size=256,
conv_kernel=4,
use_conv_bias=bias,
use_bias=bias,
mamba_hidden_act="silu",
layer_norm_epsilon=1e-5,
time_step_limit=(0.0, float("inf")),
time_step_min=0.001,
time_step_max=0.1,
time_step_floor=1e-4,
initializer_range=0.02,
rescale_prenorm_residual=False,
residual_in_fp32=False,
num_hidden_layers=1,
)
model = model_cls(mamba_config, layer_idx=0).to(device="cuda", dtype=torch.float16)
elif model_cls == MLA_Block:
# Use actual DeepSeek-V3/R1 production values
# From HuggingFace config (HunYuanPretrainedConfig defaults):
# hidden_size=4096, num_attention_heads=32
# kv_lora_rank=512, q_lora_rank=1536
# qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128
num_heads_mla = 16
qk_nope_head_dim = 64
qk_rope_head_dim = 32
v_head_dim = 64
kv_lora_rank = 256
skip_output_assert = True

model = model_cls(
hidden_size=num_features,
num_heads=num_heads_mla,
q_lora_rank=kv_lora_rank,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
bias=bias,
).to(device="cuda", dtype=torch.float16)
else:
model = model_cls(num_features, num_features, bias=bias).to(
device="cuda", dtype=torch.float16

@@ -178,6 +305,11 @@ def _run_sharding_execution_job(
num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size
else:
num_params = num_p_og // world_size + num_update
if model_cls == MLA_Block:
# since q_a_proj is simple-sharded and followed by q_a_layernorm, the layernorm params
# are NOT sharded - they have to be replicated. To account for this, we need to add the
# number of parameters of the layernorm (weight and bias) to the number of parameters of the model.
num_params += 2 * kv_lora_rank * (world_size - 1) // world_size
return num_params

def verify_local_weight_sizes(gm) -> bool:
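The MLA correction above accounts for q_a_layernorm staying replicated while everything else is divided by world_size. Arithmetic check with the execution-job values (q_lora_rank == kv_lora_rank == 256) and an assumed world_size of 2:

kv_lora_rank, world_size = 256, 2
layernorm_params = 2 * kv_lora_rank  # LayerNorm weight + bias, never sharded
naive_share = layernorm_params // world_size  # what num_p_og // world_size already counts
correction = 2 * kv_lora_rank * (world_size - 1) // world_size
assert layernorm_params == naive_share + correction  # 512 == 256 + 256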
@@ -223,6 +355,7 @@ def _run_sharding_execution_job(
gm_transformed,
check_transformed_graph=combined_graph_check,
_get_expected_num_params=_get_expected_num_params,
skip_output_assert=skip_output_assert,
)

@@ -248,6 +381,47 @@ def _run_pattern_detection_job(
hidden_size=num_features,
num_key_value_heads=num_key_value_heads,
).to(device="cuda", dtype=torch.float16)
elif model_cls == NemotronHMamba2Mixer:
# Create config for Mamba2
mamba_config = SimpleNamespace(
hidden_size=num_features,
ssm_state_size=16,
mamba_num_heads=num_heads,
mamba_head_dim=num_features // num_heads,
n_groups=num_heads,
chunk_size=256,
conv_kernel=4,
use_conv_bias=bias,
use_bias=bias,
mamba_hidden_act="silu",
layer_norm_epsilon=1e-5,
time_step_limit=(0.0, float("inf")),
time_step_min=0.001,
time_step_max=0.1,
time_step_floor=1e-4,
initializer_range=0.02,
rescale_prenorm_residual=False,
residual_in_fp32=False,
num_hidden_layers=1,
)
model = model_cls(mamba_config, layer_idx=0).to(device="cuda", dtype=torch.float16)
elif model_cls == MLA_Block:
# Create simplified MLA based on DeepSeek-V3 architecture
qk_nope_head_dim = 2
qk_rope_head_dim = 1
v_head_dim = 2
kv_lora_rank = 8

model = model_cls(
hidden_size=num_features,
num_heads=num_heads,
q_lora_rank=kv_lora_rank,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
bias=bias,
).to(device="cuda", dtype=torch.float16)
else:
model = model_cls(num_features, num_features, bias=bias).to(
device="cuda", dtype=torch.float16

@@ -344,6 +518,102 @@ def _run_pattern_detection_job(
min_local_shape=1,
)
)
elif model_cls == NemotronHMamba2Mixer:
for node in gm.graph.nodes:
if is_linear_op(node):
# in_proj should be sharded column-wise
# out_proj should be sharded row-wise with all_reduce
if "out_proj" in node.args[1].name:
dim = SplitDimension.ROW
dist_op = "all_reduce"
fused_weight_dims = None
else:
dim = SplitDimension.COLUMN
dist_op = None
fused_weight_dims = (
num_features,
num_features,
16 * num_heads,
16 * num_heads,
num_heads,
)
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=dim,
config=config,
dist_op=dist_op,
min_local_shape=1,
layer_type=LayerType.SSM,
fused_weight_dims=fused_weight_dims,
)
)
if is_op(node, torch.ops.auto_deploy.torch_causal_conv1d):
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=SplitDimension.COLUMN,
config=config,
dist_op=None,
min_local_shape=1,
layer_type=LayerType.SSM,
fused_weight_dims=(num_features, 16 * num_heads, 16 * num_heads),
)
)
if is_op(node, torch.ops.auto_deploy.torch_ssm):
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=SplitDimension.COLUMN,
config=config,
dist_op=None,
min_local_shape=1,
layer_type=LayerType.SSM,
fused_weight_dims=None,
)
)
if len(node.args) > 1 and "norm_weight" in node.args[0].name:
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=SplitDimension.COLUMN,
config=config,
dist_op=None,
min_local_shape=1,
layer_type=LayerType.SSM,
fused_weight_dims=None,
)
)
elif model_cls == MLA_Block:
for node in gm.graph.nodes:
if is_linear_op(node):
# kv_a_proj_with_mqa: gather (no sharding)
# q_b_proj/kv_b_proj: column-wise
# o_proj: row-wise with all_reduce
min_local_shape = 2
if "o_proj" in node.args[1].name:
dim = SplitDimension.ROW
dist_op = "all_reduce"
elif (
"kv_a_proj_with_mqa" in node.args[1].name or "q_a_proj" in node.args[1].name
):
# This is simple-shard gather
dim = SplitDimension.COLUMN
dist_op = "all_gather"
min_local_shape = 1
else:
dim = SplitDimension.COLUMN
dist_op = None
expected_transformations.append(
WeightShardingInfo(
target_node=node.name,
split_dim=dim,
config=config,
dist_op=dist_op,
min_local_shape=min_local_shape,
layer_type=LayerType.MLA,
)
)

# get detected transformations
optimizer = InferenceOptimizer(

@@ -378,6 +648,8 @@ def _run_pattern_detection_job(
(FP8MLP, "torch_dist_all_reduce"),
(nn.Linear, "torch_dist_all_gather"),
(GQA_Block, "torch_dist_all_reduce"),
(NemotronHMamba2Mixer, "torch_dist_all_reduce"),
(MLA_Block, "torch_dist_all_reduce"),
),
)
def test_sharding(

@@ -403,6 +675,8 @@ def test_sharding(
(FP8MLP, "torch_dist_all_reduce"),
(nn.Linear, "torch_dist_all_gather"),
(GQA_Block, "torch_dist_all_reduce"),
(NemotronHMamba2Mixer, "torch_dist_all_reduce"),
(MLA_Block, "torch_dist_all_reduce"),
),
)
def test_sharding_pattern_detection(