mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
This commit is contained in:
parent
4868772ad7
commit
5845951538
@ -31,31 +31,11 @@ from torch.fx import GraphModule, Node
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils._graph import del_attr_by_name, get_attr_by_name, set_attr_by_name
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.pattern_matcher import ADPatternMatcherPass
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
def _get_attr_by_name(obj, name):
|
||||
for part in name.split("."):
|
||||
obj = getattr(obj, part)
|
||||
return obj
|
||||
|
||||
|
||||
def _set_attr_by_name(obj, name, value):
|
||||
parts = name.split(".")
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], value)
|
||||
|
||||
|
||||
def _del_attr_by_name(obj, name):
|
||||
parts = name.split(".")
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
delattr(obj, parts[-1])
|
||||
|
||||
|
||||
_PATTERN_INPUT_NAME = "a_log_like"
|
||||
|
||||
|
||||
@ -82,13 +62,13 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
|
||||
|
||||
new_param_name = param_name.replace("A_log", "A_fused")
|
||||
try:
|
||||
_get_attr_by_name(gm, new_param_name)
|
||||
get_attr_by_name(gm, new_param_name)
|
||||
return new_param_name
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
a_log = _get_attr_by_name(gm, param_name)
|
||||
a_log = get_attr_by_name(gm, param_name)
|
||||
except AttributeError:
|
||||
ad_logger.warning(f"Could not find attribute {param_name} in gm.")
|
||||
return None
|
||||
@ -96,7 +76,7 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
|
||||
with torch.no_grad():
|
||||
a_fused = -torch.exp(a_log.float())
|
||||
|
||||
_set_attr_by_name(
|
||||
set_attr_by_name(
|
||||
gm,
|
||||
new_param_name,
|
||||
nn.Parameter(a_fused, requires_grad=False),
|
||||
@ -120,7 +100,7 @@ def _remove_unused_a_log_params(gm: GraphModule) -> bool:
|
||||
if not name.endswith("A_log") or name in used_a_log_targets:
|
||||
return
|
||||
try:
|
||||
_del_attr_by_name(gm, name)
|
||||
del_attr_by_name(gm, name)
|
||||
removed = True
|
||||
except AttributeError:
|
||||
ad_logger.warning(f"Failed to delete unused parameter {name} from GraphModule.")
|
||||
|
||||
@ -33,6 +33,7 @@ from .....functional import AllReduceStrategy
|
||||
from ...custom_ops.trtllm_dist import is_trtllm_op_available
|
||||
from ...models.factory import ModelFactory, ShardingConfigSource
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils._graph import del_attr_by_name
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import (
|
||||
LayerSubgraph,
|
||||
@ -1467,7 +1468,12 @@ def _insert_sharded_moe(
|
||||
for expert in (
|
||||
w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove
|
||||
):
|
||||
delattr(gm, expert.target)
|
||||
try:
|
||||
del_attr_by_name(gm, expert.target)
|
||||
except AttributeError:
|
||||
ad_logger.warning(
|
||||
f"Failed to delete unused parameter {expert.target} from GraphModule."
|
||||
)
|
||||
|
||||
|
||||
def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node:
|
||||
|
||||
@ -401,3 +401,55 @@ def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
|
||||
gm, output_node = get_output_node(model)
|
||||
lm_head_node = get_lm_head_node(gm, output_node)
|
||||
return get_weight_tensor(gm, lm_head_node)
|
||||
|
||||
|
||||
def get_attr_by_name(obj, name):
|
||||
"""Get an attribute specified by a dot-separated path on an object.
|
||||
|
||||
Args:
|
||||
obj: The root object from which to resolve the attribute path.
|
||||
name (str): Dot-separated attribute path (e.g., "a.b.c").
|
||||
|
||||
Returns:
|
||||
The value of the resolved attribute.
|
||||
|
||||
Raises:
|
||||
AttributeError: If any component in the path does not exist.
|
||||
"""
|
||||
for part in name.split("."):
|
||||
obj = getattr(obj, part)
|
||||
return obj
|
||||
|
||||
|
||||
def set_attr_by_name(obj, name, value):
|
||||
"""Set an attribute specified by a dot-separated path on an object.
|
||||
|
||||
Args:
|
||||
obj: The root object on which to set the attribute.
|
||||
name (str): Dot-separated attribute path (e.g., "a.b.c").
|
||||
value: The value to assign to the target attribute.
|
||||
|
||||
Raises:
|
||||
AttributeError: If any intermediate component in the path does not exist.
|
||||
"""
|
||||
parts = name.split(".")
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], value)
|
||||
|
||||
|
||||
def del_attr_by_name(obj, name):
|
||||
"""Delete an attribute specified by a dot-separated path from an object.
|
||||
|
||||
Args:
|
||||
obj: The root object from which to delete the attribute.
|
||||
name (str): Dot-separated attribute path (e.g., "a.b.c").
|
||||
|
||||
Raises:
|
||||
AttributeError: If any intermediate component in the path does not exist
|
||||
or if the final attribute does not exist.
|
||||
"""
|
||||
parts = name.split(".")
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
delattr(obj, parts[-1])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user