mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-06 03:01:50 +08:00
[https://nvbugs/5819002][fix] fix sharding tests (#10775)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in:
parent
d43be7b65e
commit
d8e6e22060
@ -43,6 +43,7 @@ from ...utils.node_utils import (
|
||||
extract_weight_nodes,
|
||||
filtered_nodes,
|
||||
get_all_layer_subgraphs,
|
||||
get_all_weights_in_subgraph,
|
||||
get_layer_after_linear_node,
|
||||
is_any_attention_op,
|
||||
is_any_lin_op,
|
||||
@ -1060,31 +1061,6 @@ def _resolve_tp_cls_from_node(node: Node):
|
||||
return WeightShardingInfo
|
||||
|
||||
|
||||
def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int:
|
||||
"""Helper to get the first dimension size of an argument (Node or Tensor)."""
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return arg.shape[0]
|
||||
if isinstance(arg, Node):
|
||||
if arg.op == "get_attr":
|
||||
# Traverse attributes to find the tensor
|
||||
obj = gm
|
||||
for atom in arg.target.split("."):
|
||||
obj = getattr(obj, atom)
|
||||
return obj.shape[0]
|
||||
if "val" in arg.meta:
|
||||
return shape(arg)[0]
|
||||
raise ValueError(f"Cannot determine shape[0] for {arg}")
|
||||
|
||||
|
||||
def get_all_weights_in_subgraph(
|
||||
sources: list[Node],
|
||||
sinks: list[Node],
|
||||
):
|
||||
"""Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
|
||||
weight_nodes = subgraph(sources, sinks, include=lambda n: n.op == "get_attr")
|
||||
return weight_nodes
|
||||
|
||||
|
||||
def init_process_grid_from_config(
|
||||
config: ShardingTransformConfig,
|
||||
) -> Dict[ShardingDim, Dict[str, int]]:
|
||||
|
||||
@ -143,8 +143,19 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
|
||||
return input_params, weight_params, output_params
|
||||
|
||||
|
||||
def extract_weight_name(node: Node) -> str:
|
||||
def get_all_weights_in_subgraph(
|
||||
sources: list[Node],
|
||||
sinks: list[Node],
|
||||
):
|
||||
"""Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
|
||||
weight_nodes = subgraph(sources, sinks, include=is_weight_node)
|
||||
return weight_nodes
|
||||
|
||||
|
||||
def extract_weight_name(node: Node) -> Union[str, bool]:
|
||||
weight_nodes = extract_weight_nodes(node)
|
||||
if len(weight_nodes.weights) == 0:
|
||||
return False
|
||||
return weight_nodes.weights[0].node_key
|
||||
|
||||
|
||||
@ -431,6 +442,10 @@ def is_dist_op(node: Node) -> bool:
|
||||
return is_op(node, dist_ops)
|
||||
|
||||
|
||||
def is_weight_node(node: Node) -> bool:
|
||||
return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
|
||||
|
||||
|
||||
def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
|
||||
"""Get a user from a node if the node matches a given op set and num of users."""
|
||||
if node is None:
|
||||
|
||||
@ -353,10 +353,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mt
|
||||
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5814203)
|
||||
unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py::test_flashinfer_fused_moe_matches_torch_moe SKIP (https://nvbugs/5814215)
|
||||
full:sm89/accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5814504)
|
||||
unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-False-8] SKIP (https://nvbugs/5819002)
|
||||
unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-True-8] SKIP (https://nvbugs/5819002)
|
||||
unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-False-8] SKIP (https://nvbugs/5819002)
|
||||
unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-True-8] SKIP (https://nvbugs/5819002)
|
||||
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5819005)
|
||||
unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks SKIP (https://nvbugs/5819014)
|
||||
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5819019)
|
||||
|
||||
@ -572,7 +572,9 @@ def _run_pattern_detection_job(
|
||||
fused_weight_dims=None,
|
||||
)
|
||||
)
|
||||
if len(node.args) > 1 and "norm_weight" in node.args[0].name:
|
||||
if len(node.args) > 1 and (
|
||||
"norm_weight" in node.args[0].name or "a_log" in node.args[0].name
|
||||
):
|
||||
expected_transformations.append(
|
||||
WeightShardingInfo(
|
||||
target_node=node.name,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user