diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index 85f0e24a4e..2df3ed61b4 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -43,6 +43,7 @@ from ...utils.node_utils import (
     extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
+    get_all_weights_in_subgraph,
     get_layer_after_linear_node,
     is_any_attention_op,
     is_any_lin_op,
@@ -1060,31 +1061,6 @@ def _resolve_tp_cls_from_node(node: Node):
     return WeightShardingInfo
 
 
-def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int:
-    """Helper to get the first dimension size of an argument (Node or Tensor)."""
-    if isinstance(arg, torch.Tensor):
-        return arg.shape[0]
-    if isinstance(arg, Node):
-        if arg.op == "get_attr":
-            # Traverse attributes to find the tensor
-            obj = gm
-            for atom in arg.target.split("."):
-                obj = getattr(obj, atom)
-            return obj.shape[0]
-        if "val" in arg.meta:
-            return shape(arg)[0]
-    raise ValueError(f"Cannot determine shape[0] for {arg}")
-
-
-def get_all_weights_in_subgraph(
-    sources: list[Node],
-    sinks: list[Node],
-):
-    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
-    weight_nodes = subgraph(sources, sinks, include=lambda n: n.op == "get_attr")
-    return weight_nodes
-
-
 def init_process_grid_from_config(
     config: ShardingTransformConfig,
 ) -> Dict[ShardingDim, Dict[str, int]]:
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
index 720deab984..83cc46309b 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
@@ -143,8 +143,19 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
     return input_params, weight_params, output_params
 
 
-def extract_weight_name(node: Node) -> str:
+def get_all_weights_in_subgraph(
+    sources: list[Node],
+    sinks: list[Node],
+):
+    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
+    weight_nodes = subgraph(sources, sinks, include=is_weight_node)
+    return weight_nodes
+
+
+def extract_weight_name(node: Node) -> Union[str, bool]:
     weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
+        return False
     return weight_nodes.weights[0].node_key
 
 
@@ -431,6 +442,10 @@ def is_dist_op(node: Node) -> bool:
     return is_op(node, dist_ops)
 
 
+def is_weight_node(node: Node) -> bool:
+    return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
+
+
 def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
     """Get a user from a node if the node matches a given op set and num of users."""
     if node is None:
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 126f131762..13e58d2eb8 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -353,10 +353,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mt
 stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5814203)
 unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py::test_flashinfer_fused_moe_matches_torch_moe SKIP (https://nvbugs/5814215)
 full:sm89/accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5814504)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-False-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-True-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-False-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-True-8] SKIP (https://nvbugs/5819002)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5819005)
 unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks SKIP (https://nvbugs/5819014)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5819019)
diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
index e226d85c61..170d9e9b5c 100644
--- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
+++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
@@ -572,7 +572,9 @@ def _run_pattern_detection_job(
                 fused_weight_dims=None,
             )
         )
-        if len(node.args) > 1 and "norm_weight" in node.args[0].name:
+        if len(node.args) > 1 and (
+            "norm_weight" in node.args[0].name or "a_log" in node.args[0].name
+        ):
             expected_transformations.append(
                 WeightShardingInfo(
                     target_node=node.name,
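For context on what this patch changes: `get_all_weights_in_subgraph` moves from the sharding transform into `node_utils.py`, and its weight filter is tightened from "any `get_attr` node" to the new `is_weight_node` predicate, which additionally requires a known, non-scalar shape. The sketch below is a hypothetical, self-contained illustration of that predicate on a toy `torch.fx` graph; the internal `subgraph`, `has_shape`, and `shape` helpers are not available outside the repo, so an attribute-resolution shape check stands in for the metadata-based one used in the patch.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.fx import GraphModule, Node, symbolic_trace


class TinyBlock(nn.Module):
    """Toy module whose traced graph exposes its weights as get_attr nodes."""

    def __init__(self):
        super().__init__()
        self.w_in = nn.Parameter(torch.randn(16, 8))
        self.w_out = nn.Parameter(torch.randn(8, 16))

    def forward(self, x):
        return F.linear(F.relu(F.linear(x, self.w_in)), self.w_out)


def is_weight_node(gm: GraphModule, node: Node) -> bool:
    # Same intent as the new predicate in node_utils.py: a get_attr node
    # backed by a tensor of rank >= 1 (scalars and non-tensors are skipped).
    # The real predicate reads shapes from node metadata via has_shape/shape;
    # resolving the attribute on the GraphModule is a stand-in for that.
    if node.op != "get_attr" or not node.target:
        return False
    obj = gm
    for atom in str(node.target).split("."):
        obj = getattr(obj, atom)
    return isinstance(obj, torch.Tensor) and obj.dim() > 0


gm = symbolic_trace(TinyBlock())
weights = [n for n in gm.graph.nodes if is_weight_node(gm, n)]
print([n.target for n in weights])  # ['w_in', 'w_out']
```

Resolving the attribute directly is fine for a freshly traced module, but the metadata-based check in the patch avoids touching tensor storage and presumably also covers graphs where shapes live only in FakeTensor `meta["val"]` entries.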