[https://nvbugs/5819002][fix] fix sharding tests (#10775)

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Grzegorz Kwasniewski 2026-01-22 20:02:48 +01:00 committed by GitHub
parent d43be7b65e
commit d8e6e22060
4 changed files with 20 additions and 31 deletions

@@ -43,6 +43,7 @@ from ...utils.node_utils import (
     extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
+    get_all_weights_in_subgraph,
     get_layer_after_linear_node,
     is_any_attention_op,
     is_any_lin_op,
@@ -1060,31 +1061,6 @@ def _resolve_tp_cls_from_node(node: Node):
     return WeightShardingInfo


-def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int:
-    """Helper to get the first dimension size of an argument (Node or Tensor)."""
-    if isinstance(arg, torch.Tensor):
-        return arg.shape[0]
-    if isinstance(arg, Node):
-        if arg.op == "get_attr":
-            # Traverse attributes to find the tensor
-            obj = gm
-            for atom in arg.target.split("."):
-                obj = getattr(obj, atom)
-            return obj.shape[0]
-        if "val" in arg.meta:
-            return shape(arg)[0]
-    raise ValueError(f"Cannot determine shape[0] for {arg}")
-
-
-def get_all_weights_in_subgraph(
-    sources: list[Node],
-    sinks: list[Node],
-):
-    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
-    weight_nodes = subgraph(sources, sinks, include=lambda n: n.op == "get_attr")
-    return weight_nodes
-
-
 def init_process_grid_from_config(
     config: ShardingTransformConfig,
 ) -> Dict[ShardingDim, Dict[str, int]]:
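For context, the deleted _get_dim0_from_arg helper used the standard torch.fx idiom for resolving a get_attr node to the tensor it refers to. A minimal standalone sketch of that idiom (the function name resolve_get_attr is illustrative, not part of the codebase):

import torch
from torch import fx

def resolve_get_attr(gm: fx.GraphModule, node: fx.Node) -> torch.Tensor:
    # node.target of a get_attr node is a dotted attribute path such as
    # "layers.0.weight"; walk it from the module root to the stored tensor.
    obj = gm
    for atom in node.target.split("."):
        obj = getattr(obj, atom)
    return obj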

@@ -143,8 +143,19 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
     return input_params, weight_params, output_params


-def extract_weight_name(node: Node) -> str:
+def get_all_weights_in_subgraph(
+    sources: list[Node],
+    sinks: list[Node],
+):
+    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
+    weight_nodes = subgraph(sources, sinks, include=is_weight_node)
+    return weight_nodes
+
+
+def extract_weight_name(node: Node) -> Union[str, bool]:
     weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
+        return False
     return weight_nodes.weights[0].node_key
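Because extract_weight_name now returns False for nodes without attached weights, call sites must guard before treating the result as a string. A hedged usage sketch (the loop and the process helper are illustrative, not from the codebase):

for node in gm.graph.nodes:
    weight_name = extract_weight_name(node)
    if weight_name is False:
        continue  # no weight attached to this node; skip it
    process(weight_name)  # weight_name is the node_key of the first weight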
@@ -431,6 +442,10 @@ def is_dist_op(node: Node) -> bool:
     return is_op(node, dist_ops)


+def is_weight_node(node: Node) -> bool:
+    return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
+
+
 def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
     """Get a user from a node if the node matches a given op set and num of users."""
     if node is None:
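Note that is_weight_node is stricter than the lambda it replaces in get_all_weights_in_subgraph: in addition to op == "get_attr", it requires recorded shape metadata with at least one dimension. A small sketch of the practical difference, assuming a traced GraphModule gm with shape propagation already run:

attr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]    # old, permissive filter
weight_nodes = [n for n in gm.graph.nodes if is_weight_node(n)]   # new, shape-aware filter
# weight_nodes is a subset of attr_nodes: 0-dim tensors and attributes
# without shape metadata are no longer treated as weights.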

@@ -353,10 +353,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mt
 stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5814203)
 unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py::test_flashinfer_fused_moe_matches_torch_moe SKIP (https://nvbugs/5814215)
 full:sm89/accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5814504)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-False-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-True-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-False-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-True-8] SKIP (https://nvbugs/5819002)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5819005)
 unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks SKIP (https://nvbugs/5819014)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5819019)

@@ -572,7 +572,9 @@ def _run_pattern_detection_job(
                     fused_weight_dims=None,
                 )
             )
-            if len(node.args) > 1 and "norm_weight" in node.args[0].name:
+            if len(node.args) > 1 and (
+                "norm_weight" in node.args[0].name or "a_log" in node.args[0].name
+            ):
                 expected_transformations.append(
                     WeightShardingInfo(
                         target_node=node.name,
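The widened condition accounts for Mamba2-style mixers, which carry an A_log parameter alongside norm weights; both must now appear as expected WeightShardingInfo entries during pattern detection. A hedged sketch of the predicate with illustrative argument names (not taken from the test):

def expects_weight_sharding(arg_name: str) -> bool:
    # Mirrors the updated test condition above.
    return "norm_weight" in arg_name or "a_log" in arg_name

assert expects_weight_sharding("mixer_norm_weight")  # matched before this fix
assert expects_weight_sharding("mixer_a_log")        # newly matched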