[#10056][fix] AutoDeploy: Handle deletion of nested params in sharding (#10376)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
This commit is contained in:
Gal Hubara-Agam 2026-01-01 15:11:11 +02:00 committed by GitHub
parent 4868772ad7
commit 5845951538
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 64 additions and 26 deletions

View File

@ -31,31 +31,11 @@ from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import del_attr_by_name, get_attr_by_name, set_attr_by_name
from ...utils.logger import ad_logger
from ...utils.pattern_matcher import ADPatternMatcherPass
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
def _get_attr_by_name(obj, name):
    """Resolve a dot-separated attribute path (e.g., "a.b.c") on *obj*.

    Raises AttributeError if any component along the path is missing.
    """
    for part in name.split("."):
        obj = getattr(obj, part)
    return obj
def _set_attr_by_name(obj, name, value):
    """Set the attribute at a dot-separated path (e.g., "a.b.c") on *obj*.

    Walks all intermediate components, then assigns *value* to the final one.
    Raises AttributeError if any intermediate component is missing.
    """
    parts = name.split(".")
    for part in parts[:-1]:
        obj = getattr(obj, part)
    setattr(obj, parts[-1], value)
def _del_attr_by_name(obj, name):
    """Delete the attribute at a dot-separated path (e.g., "a.b.c") from *obj*.

    Walks all intermediate components, then deletes the final one.
    Raises AttributeError if any component along the path is missing.
    """
    parts = name.split(".")
    for part in parts[:-1]:
        obj = getattr(obj, part)
    delattr(obj, parts[-1])
_PATTERN_INPUT_NAME = "a_log_like"
@ -82,13 +62,13 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
new_param_name = param_name.replace("A_log", "A_fused")
try:
_get_attr_by_name(gm, new_param_name)
get_attr_by_name(gm, new_param_name)
return new_param_name
except AttributeError:
pass
try:
a_log = _get_attr_by_name(gm, param_name)
a_log = get_attr_by_name(gm, param_name)
except AttributeError:
ad_logger.warning(f"Could not find attribute {param_name} in gm.")
return None
@ -96,7 +76,7 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
with torch.no_grad():
a_fused = -torch.exp(a_log.float())
_set_attr_by_name(
set_attr_by_name(
gm,
new_param_name,
nn.Parameter(a_fused, requires_grad=False),
@ -120,7 +100,7 @@ def _remove_unused_a_log_params(gm: GraphModule) -> bool:
if not name.endswith("A_log") or name in used_a_log_targets:
return
try:
_del_attr_by_name(gm, name)
del_attr_by_name(gm, name)
removed = True
except AttributeError:
ad_logger.warning(f"Failed to delete unused parameter {name} from GraphModule.")

View File

@ -33,6 +33,7 @@ from .....functional import AllReduceStrategy
from ...custom_ops.trtllm_dist import is_trtllm_op_available
from ...models.factory import ModelFactory, ShardingConfigSource
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import del_attr_by_name
from ...utils.logger import ad_logger
from ...utils.node_utils import (
LayerSubgraph,
@ -1467,7 +1468,12 @@ def _insert_sharded_moe(
for expert in (
w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove
):
delattr(gm, expert.target)
try:
del_attr_by_name(gm, expert.target)
except AttributeError:
ad_logger.warning(
f"Failed to delete unused parameter {expert.target} from GraphModule."
)
def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node:

View File

@ -401,3 +401,55 @@ def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
gm, output_node = get_output_node(model)
lm_head_node = get_lm_head_node(gm, output_node)
return get_weight_tensor(gm, lm_head_node)
def get_attr_by_name(obj, name):
    """Get an attribute specified by a dot-separated path on an object.

    Args:
        obj: The root object from which to resolve the attribute path.
        name (str): Dot-separated attribute path (e.g., "a.b.c").

    Returns:
        The value of the resolved attribute.

    Raises:
        AttributeError: If any component in the path does not exist.
    """
    resolved = obj
    # Walk the path one segment at a time; getattr raises on a missing link.
    for segment in name.split("."):
        resolved = getattr(resolved, segment)
    return resolved
def set_attr_by_name(obj, name, value):
    """Set an attribute specified by a dot-separated path on an object.

    Args:
        obj: The root object on which to set the attribute.
        name (str): Dot-separated attribute path (e.g., "a.b.c").
        value: The value to assign to the target attribute.

    Raises:
        AttributeError: If any intermediate component in the path does not exist.
    """
    # Split off the final segment; everything before it is the parent path.
    parent_path, _, leaf = name.rpartition(".")
    parent = obj
    if parent_path:
        for segment in parent_path.split("."):
            parent = getattr(parent, segment)
    setattr(parent, leaf, value)
def del_attr_by_name(obj, name):
    """Delete an attribute specified by a dot-separated path from an object.

    Args:
        obj: The root object from which to delete the attribute.
        name (str): Dot-separated attribute path (e.g., "a.b.c").

    Raises:
        AttributeError: If any intermediate component in the path does not exist
            or if the final attribute does not exist.
    """
    # Split off the final segment; everything before it is the holder's path.
    holder_path, _, leaf = name.rpartition(".")
    holder = obj
    if holder_path:
        for segment in holder_path.split("."):
            holder = getattr(holder, segment)
    delattr(holder, leaf)