diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
index d86c8244a5..2ac5f72ba3 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
@@ -17,8 +17,9 @@ from ...custom_ops.quantization.quant import (
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import (
+    WeightBiasInfoCache,
+    extract_weight_nodes,
     get_quantization_params_from_linear_node,
-    get_weight_info,
     is_bmm_op,
     is_linear_op,
 )
@@ -109,27 +110,28 @@ class Quantization(BaseTransform):
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        qcfg = factory.get_quant_config()
-        if not qcfg or (
-            qcfg.get("quant_algo", "").upper() != self.algo_name
-            and qcfg.get("quant_method", "").upper() != self.algo_name
-        ):
-            return gm, TransformInfo(
-                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
-            )
-        excluded = qcfg.get("exclude_modules", [])
-        cnt = 0
-        for n in gm.graph.nodes:
-            if not is_linear_op(n):
-                continue
-            if should_skip_quantization(n, excluded):
-                continue
-            self._insert_quantized_linear(gm, n, is_quantized_graph=False)
-            cnt += 1
+        with WeightBiasInfoCache():
+            qcfg = factory.get_quant_config()
+            if not qcfg or (
+                qcfg.get("quant_algo", "").upper() != self.algo_name
+                and qcfg.get("quant_method", "").upper() != self.algo_name
+            ):
+                return gm, TransformInfo(
+                    skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+                )
+            excluded = qcfg.get("exclude_modules", [])
+            cnt = 0
+            for n in gm.graph.nodes:
+                if not is_linear_op(n):
+                    continue
+                if should_skip_quantization(n, excluded):
+                    continue
+                self._insert_quantized_linear(gm, n, is_quantized_graph=False)
+                cnt += 1

-        return gm, TransformInfo(
-            skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=(cnt == 0)
-        )
+            return gm, TransformInfo(
+                skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=(cnt == 0)
+            )

     def _insert_quantized_linear(
         self,
@@ -141,9 +143,10 @@ class Quantization(BaseTransform):

         The state_dict is also updated to contain the sharded weights.
""" - lin_weight = get_weight_info(node) - if lin_weight is None: + weight_nodes = extract_weight_nodes(node) + if len(weight_nodes.weights) == 0: raise ValueError(f"Linear node {node.name} has no weight") + lin_weight = weight_nodes.weights[0] new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False) modname, _, attrname = lin_weight.node_key.rpartition(".") @@ -564,13 +567,15 @@ class FP8BMMQuantizationFromConfig(Quantization): excluded = qcfg.get("exclude_modules", []) cnt = 0 - for n in gm.graph.nodes: - if not is_bmm_op(n): - continue - if should_skip_quantization(n, excluded): - continue - if self._insert_quantized_bmm(gm, n, is_quantized_graph=False): - cnt += 1 + + with WeightBiasInfoCache(): + for n in gm.graph.nodes: + if not is_bmm_op(n): + continue + if should_skip_quantization(n, excluded): + continue + if self._insert_quantized_bmm(gm, n, is_quantized_graph=False): + cnt += 1 return gm, TransformInfo( skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=True diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 9d6139dfec..48c8053863 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -38,11 +38,12 @@ from ...utils.logger import ad_logger from ...utils.node_utils import ( LayerSubgraph, LayerType, + WeightBiasInfoCache, bfs, extract_weight_name, + extract_weight_nodes, filtered_nodes, get_all_layer_subgraphs, - get_all_weight_infos, get_all_weights_in_subgraph, is_any_attention_op, is_any_lin_op, @@ -821,35 +822,44 @@ class Sharding(BaseTransform): ) info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) - for source in config.sharding_source: - if source == ShardingSource.FACTORY: - if len(config.factory_config) == 0: - ad_logger.debug( - "No factory config found. Skipping sharding from factory config" + with WeightBiasInfoCache(): + for source in config.sharding_source: + if source == ShardingSource.FACTORY: + if len(config.factory_config) == 0: + ad_logger.debug( + "No factory config found. Skipping sharding from factory config" + ) + continue + ad_logger.info("Applying sharding from factory config") + info += detect_sharding_from_config( + gm, transform_container, ShardingSource.FACTORY + ) + elif source == ShardingSource.MANUAL: + if len(config.manual_config) == 0: + ad_logger.debug( + "No manual config found. Skipping sharding from manual config" + ) + continue + ad_logger.info("Applying sharding from manual config") + info += detect_sharding_from_config( + gm, transform_container, ShardingSource.MANUAL ) - continue - ad_logger.info("Applying sharding from factory config") - info += detect_sharding_from_config(gm, transform_container, ShardingSource.FACTORY) - elif source == ShardingSource.MANUAL: - if len(config.manual_config) == 0: - ad_logger.debug("No manual config found. 
-                    continue
-                ad_logger.info("Applying sharding from manual config")
-                info += detect_sharding_from_config(gm, transform_container, ShardingSource.MANUAL)
-            elif source == ShardingSource.HEURISTIC:
-                ad_logger.info(f"Running autodeploy sharding heuristics: {config.sharding_dims}")
-                # run TP sharding across ranks
-                if ShardingDim.TP in config.sharding_dims:
-                    info += detect_column_row_shard(gm, transform_container)
+                elif source == ShardingSource.HEURISTIC:
+                    ad_logger.info(
+                        f"Running autodeploy sharding heuristics: {config.sharding_dims}"
+                    )
+                    # run TP sharding across ranks
+                    if ShardingDim.TP in config.sharding_dims:
+                        info += detect_column_row_shard(gm, transform_container)

-                # run EP sharding across ranks
-                if ShardingDim.EP in config.sharding_dims:
-                    info += detect_ep_shard(gm, transform_container)
+                    # run EP sharding across ranks
+                    if ShardingDim.EP in config.sharding_dims:
+                        info += detect_ep_shard(gm, transform_container)

-                # run BMM sharding across ranks
-                if ShardingDim.BMM in config.sharding_dims:
-                    info += detect_dp_bmm_shard(gm, transform_container)
+                    # run BMM sharding across ranks
+                    if ShardingDim.BMM in config.sharding_dims:
+                        info += detect_dp_bmm_shard(gm, transform_container)

         return gm, info

@@ -889,18 +899,19 @@ class ShardingTransformExecutor(BaseTransform):

         num_matches = 0
         transforms = shared_config.sharding_transform_container
-        for tp_transform in transforms.weight_sharding_transforms:
-            if check_and_apply(tp_transform):
-                num_matches += 1
-        for bmm_transform in transforms.bmm_transforms:
-            if check_and_apply(bmm_transform):
-                num_matches += 1
-        for ep_transform in transforms.ep_transforms:
-            if check_and_apply(ep_transform):
-                num_matches += 1
-        for rmsnorm_transform in transforms.rmsnorm_transforms:
-            if check_and_apply(rmsnorm_transform):
-                num_matches += 1
+        with WeightBiasInfoCache():
+            for tp_transform in transforms.weight_sharding_transforms:
+                if check_and_apply(tp_transform):
+                    num_matches += 1
+            for bmm_transform in transforms.bmm_transforms:
+                if check_and_apply(bmm_transform):
+                    num_matches += 1
+            for ep_transform in transforms.ep_transforms:
+                if check_and_apply(ep_transform):
+                    num_matches += 1
+            for rmsnorm_transform in transforms.rmsnorm_transforms:
+                if check_and_apply(rmsnorm_transform):
+                    num_matches += 1

         # post-sharding cleanup transformations
         for update_transform in transforms.parameter_update_transforms:
@@ -1308,17 +1319,17 @@ def _shard_parameter_node(
         return

     # Shard weight using the unified function (also updates the parameter)
-    all_weight_infos = get_all_weight_infos(node)
+    weight_nodes = extract_weight_nodes(node)

     # Parametrized nodes must have at least one weight (for debugging)
-    assert len(all_weight_infos.weights) > 0, (
+    assert len(weight_nodes.weights) > 0, (
         f"Node {node.name} has no weights - weight mapping may be incorrect"
     )

-    for weight_info in all_weight_infos.weights:
+    for weight_node in weight_nodes.weights:
         _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=weight_info.tensor,
-            param_key=weight_info.node_key,
+            weight_tensor=weight_node.tensor,
+            param_key=weight_node.node_key,
             dim=dim,
             rank=rank,
             world_size=world_size,
@@ -1328,22 +1339,22 @@ def _shard_parameter_node(

         if quantization_cb is not None:
             quantization_cb(
                 gm=gm,
-                submod=weight_info.submod,
+                submod=weight_node.submod,
                 node=node,
-                weight_key=weight_info.node_key,
+                weight_key=weight_node.node_key,
                 weight_new_shape=weight_new_shape,
                 dim=dim,
                 rank=rank,
                 world_size=world_size,
             )

-    for bias_info in all_weight_infos.biases:
+    for bias_node in weight_nodes.biases:
         if dim == 0:
             # update bias for dim 0 --> we can handle it like the weight
             shard_weight_tensor(
                 gm=gm,
-                weight_tensor=bias_info.tensor,
-                param_key=bias_info.node_key,
+                weight_tensor=bias_node.tensor,
+                param_key=bias_node.node_key,
                 dim=dim,
                 rank=rank,
                 world_size=world_size,
@@ -1358,10 +1369,10 @@ def _shard_parameter_node(
                 args[2] = None
                 node.args = tuple(args)
                 gm.graph.erase_node(node_bias)
-            bias_param_name = bias_info.node_key.rpartition(".")[-1]
-            setattr(bias_info.submod, bias_param_name, None)
+            bias_param_name = bias_node.node_key.rpartition(".")[-1]
+            setattr(bias_node.submod, bias_param_name, None)
             gm._register_load_state_dict_pre_hook(
-                partial(_load_hook_remove, param_key=bias_info.node_key)
+                partial(_load_hook_remove, param_key=bias_node.node_key)
             )

     # # # column shard with no gather: the output is sharded
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
index efba02306f..ddf3474959 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
@@ -152,36 +152,133 @@ def get_all_weights_in_subgraph(


 def extract_weight_name(node: Node) -> Union[str, bool]:
-    """
-    Extract the weight parameter name for a compute node.
-
-    Args:
-        node: Compute node (linear, MoE, SSM, etc.)
-
-    Returns:
-        Weight parameter name (str), or False if no weight exists.
-    """
-    weight_node = get_weight_node(node)
-    if weight_node is None:
+    try:
+        weight_nodes = extract_weight_nodes(node)
+    except Exception:
         return False
-    return weight_node.target
+    if len(weight_nodes.weights) == 0:
+        return False
+    return weight_nodes.weights[0].node_key


 def get_param_or_buffer(tensor_name: str, gm: GraphModule) -> torch.Tensor:
-    if tensor_name in dict(gm.named_parameters()):
-        return gm.get_parameter(tensor_name)
-    elif tensor_name in dict(gm.named_buffers()):
-        return gm.get_buffer(tensor_name)
-    else:
-        raise KeyError(f"Tensor {tensor_name} not found in the graph")
+    param_dict = WeightBiasInfoCache.get_param_dict(gm)
+    if tensor_name in param_dict:
+        return param_dict[tensor_name]
+    buffer_dict = WeightBiasInfoCache.get_buffer_dict(gm)
+    if tensor_name in buffer_dict:
+        return buffer_dict[tensor_name]
+    raise KeyError(f"Tensor {tensor_name} not found in the graph")
+
+
+class WeightBiasInfoCache:
+    """Cache for weight and bias information to avoid repeated expensive operations.
+
+    This class manages caches for parameter names and weight shapes that are used
+    during graph transformation operations. Use it as a context manager to scope
+    the cache lifetime.
+
+    Example:
+        with WeightBiasInfoCache() as cache:
+            # All calls to get_weight_shape and extract_weight_nodes
+            # within this block use caching
+            layer_subgraphs, _ = get_all_layer_subgraphs(gm)
+        # Caches are cleared here
+    """
+
+    # Class-level reference to the currently active cache instance
+    _active_instance: "WeightBiasInfoCache" = None
+
+    def __init__(self):
+        # Cache for param/buffer dicts to avoid repeated expensive named_parameters/named_buffers calls
+        self._param_dict_cache = {}
+        self._buffer_dict_cache = {}
+        # Cache for get_weight_shape to avoid repeated expensive extract_weight_nodes calls
+        self._weight_shape_cache = {}
+        # Activate this cache instance
+        WeightBiasInfoCache._active_instance = self
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        return False
+
+    def close(self):
+        """Explicitly deactivate and clear the cache."""
+        if WeightBiasInfoCache._active_instance is self:
+            WeightBiasInfoCache._active_instance = None
+        self._param_dict_cache.clear()
+        self._buffer_dict_cache.clear()
+        self._weight_shape_cache.clear()
+
+    def __del__(self):
+        """Cleanup when the cache is garbage collected."""
+        self.close()
+
+    @classmethod
+    def is_active(cls) -> bool:
+        """Check if caching is currently enabled."""
+        return cls._active_instance is not None
+
+    @classmethod
+    def get_param_dict(cls, gm: GraphModule) -> dict:
+        """Get cached parameters dict for a GraphModule, or compute and cache it."""
+        if cls._active_instance is None:
+            return dict(gm.named_parameters())
+
+        cache = cls._active_instance._param_dict_cache
+        if gm not in cache:
+            cache[gm] = dict(gm.named_parameters())
+        return cache[gm]
+
+    @classmethod
+    def get_buffer_dict(cls, gm: GraphModule) -> dict:
+        """Get cached buffers dict for a GraphModule, or compute and cache it."""
+        if cls._active_instance is None:
+            return dict(gm.named_buffers())
+
+        cache = cls._active_instance._buffer_dict_cache
+        if gm not in cache:
+            cache[gm] = dict(gm.named_buffers())
+        return cache[gm]
+
+    @classmethod
+    def get_param_names(cls, gm: GraphModule) -> set:
+        """Get cached parameter and buffer names for a GraphModule."""
+        param_dict = cls.get_param_dict(gm)
+        buffer_dict = cls.get_buffer_dict(gm)
+        return set(param_dict.keys()).union(buffer_dict.keys())
+
+    @classmethod
+    def get_weight_shape(cls, node: Node) -> Tuple[bool, Optional[List[int]]]:
+        """Get cached weight shape for a node.
+
+        Returns:
+            Tuple of (found, value). If found is False, value should be ignored.
+        """
+        if cls._active_instance is None:
+            return False, None
+
+        cache = cls._active_instance._weight_shape_cache
+        if node in cache:
+            return True, cache[node]
+        return False, None
+
+    @classmethod
+    def set_weight_shape(cls, node: Node, shape: Optional[List[int]]):
+        """Store weight shape in cache."""
+        if cls._active_instance is not None:
+            cls._active_instance._weight_shape_cache[node] = shape


 def extract_weight_nodes(node: Node) -> WeightNodes:
     """Extracts the list of weight node and optional bias node from the given parametrized node"""
     gm = node.graph.owning_module
-    param_names = {name for name, _ in gm.named_parameters()}.union(
-        {name for name, _ in gm.named_buffers()}
-    )
+
+    # Use cached param_names to avoid repeated expensive named_parameters/named_buffers calls
+    param_names = WeightBiasInfoCache.get_param_names(gm)

     def find_get_attr_node(weight_node: Node) -> Node:
         """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""
@@ -255,18 +352,21 @@ def extract_weight_nodes(node: Node) -> WeightNodes:
     return WeightNodes(weights=weight_nodes, biases=bias_nodes)


+def get_weight_node(node: Node) -> Node:
+    """Get the primary weight node for a compute node."""
+    weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
+        raise ValueError(f"Node {node.name} has no weight")
+    return weight_nodes.weights[0].node
+
+
 def num_users_of_weight_node(node: Node) -> int:
-    """
-    Get the number of users of the weight node.
-
-    Args:
-        node: Compute node (linear, MoE, SSM, etc.)
-
-    Returns:
-        Number of users of the primary weight node, or 0 if no weight exists.
-    """
-    weight_node = get_weight_node(node)
-    return len(weight_node.users) if weight_node is not None else 0
+    """Returns the number of users of the weight node of the given parametrized node."""
+    try:
+        weight_node = get_weight_node(node)
+    except ValueError:
+        return 0
+    return len(weight_node.users)


 def get_op_overload_packet(node: Union[OpOverloadPacket, OpOverload]) -> OpOverloadPacket:
@@ -469,167 +569,6 @@ def is_weight_node(node: Node) -> bool:
     return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0


-# Auxiliary ops that may appear between a weight node and its consumer compute node
-_WEIGHT_AUX_OPS = frozenset(
-    {
-        torch.ops.aten.to.dtype,
-        torch.ops.aten.view.default,
-    }
-)
-
-
-def precompute_weight_node_mapping(gm: GraphModule) -> None:
-    """
-    Pre-compute weight-to-consumer mapping for all weight nodes in the graph.
-
-    For each weight node (get_attr), finds the consumer compute node by traversing
-    through auxiliary ops (to.dtype, view.default). Stores the mapping in consumer
-    node's metadata:
-    - node.meta["weight_nodes"]: list of weight nodes (non-bias)
-    - node.meta["bias_nodes"]: list of bias nodes
-
-    This enables O(1) weight node lookup instead of O(depth) backward traversal.
-    Called automatically on first weight lookup via lazy initialization.
-
-    GUARANTEES (verified by assertions for debugging):
-    - Called exactly once per GraphModule
-    - No duplicate weight/bias nodes in any consumer's lists
-    - Each weight node mapped to exactly one consumer
-    """
-    # Early return if already computed
-    if "_weight_mapping_computed" in gm.meta:
-        return
-    gm.meta["_weight_mapping_computed"] = True
-
-    for node in gm.graph.nodes:
-        if not is_weight_node(node):
-            continue
-
-        is_bias = node.target.endswith("bias")
-
-        # Find the consumer compute node by traversing through auxiliary ops
-        current = node
-        visited = {current}
-
-        while True:
-            # Get users of current node
-            users = list(current.users.keys())
-            if not users:
-                break
-
-            # Check if any user is a compute node (not an auxiliary op)
-            consumer_found = None
-            aux_node = None
-
-            for user in users:
-                if is_bias:
-                    if "bias_nodes" not in user.meta:
-                        user.meta["bias_nodes"] = []
-                    # ASSERTION: Each weight node should be mapped exactly once
-                    assert node not in user.meta["bias_nodes"], (
-                        f"Duplicate bias node {node.name} found for consumer {user.name}"
-                    )
-                    user.meta["bias_nodes"].append(node)
-                else:
-                    if "weight_nodes" not in user.meta:
-                        user.meta["weight_nodes"] = []
-                    # ASSERTION: Each weight node should be mapped exactly once
-                    assert node not in user.meta["weight_nodes"], (
-                        f"Duplicate weight node {node.name} found for consumer {user.name}"
-                    )
-                    user.meta["weight_nodes"].append(node)
-                if user.target in _WEIGHT_AUX_OPS:
-                    # This is an auxiliary op, continue traversing
-                    aux_node = user
-                else:
-                    # This is a potential consumer compute node
-                    consumer_found = user
-                    break
-
-            if consumer_found is not None:
-                # Found the consumer, return
-                break
-            elif aux_node is not None and aux_node not in visited:
-                # Continue through auxiliary op
-                current = aux_node
-                visited.add(current)
-            else:
-                # No more nodes to traverse
-                break
-
-
-def _ensure_weight_mapping(node: Node) -> None:
-    """Ensure weight node mapping is computed. Lazily calls precompute if needed."""
-    gm = node.graph.owning_module
-    if "_weight_mapping_computed" not in gm.meta or not gm.meta["_weight_mapping_computed"]:
-        precompute_weight_node_mapping(gm)
-
-
-def get_weight_node(node: Node) -> Optional[Node]:
-    """Get the primary weight node for a compute node"""
-    _ensure_weight_mapping(node)
-    weight_nodes = node.meta.get("weight_nodes", [])
-    return weight_nodes[0] if weight_nodes else None
-
-
-def get_weight_nodes(node: Node) -> List[Node]:
-    """Get all weight nodes for a compute node"""
-    _ensure_weight_mapping(node)
-    return node.meta.get("weight_nodes", [])
-
-
-def get_bias_nodes(node: Node) -> List[Node]:
-    """Get all bias nodes for a compute node"""
-    _ensure_weight_mapping(node)
-    return node.meta.get("bias_nodes", [])
-
-
-@dataclass
-class WeightInfo:
-    """Lightweight weight info extracted from a weight node."""
-
-    node: Node
-    node_key: str
-    tensor: torch.Tensor
-    submod: nn.Module
-
-
-def _weight_node_to_info(weight_node: Node, gm: GraphModule) -> WeightInfo:
-    """Convert a weight node to WeightInfo."""
-    node_key = weight_node.target
-    tensor = get_param_or_buffer(node_key, gm)
-    submod = gm.get_submodule(node_key.rpartition(".")[0])
-    return WeightInfo(node=weight_node, node_key=node_key, tensor=tensor, submod=submod)
-
-
-def get_weight_info(node: Node) -> Optional[WeightInfo]:
-    """Extract weight info for the primary weight of a compute node."""
-    weight_node = get_weight_node(node)
-    if weight_node is None:
-        return None
-    return _weight_node_to_info(weight_node, node.graph.owning_module)
-
-
-@dataclass
-class AllWeightInfos:
-    """Container for all weight and bias infos of a compute node."""
-
-    weights: List[WeightInfo]
-    biases: List[WeightInfo]
-
-
-def get_all_weight_infos(node: Node) -> AllWeightInfos:
-    """Extract all weight and bias infos for a compute node."""
-    gm = node.graph.owning_module
-    weight_nodes = get_weight_nodes(node)
-    bias_nodes = get_bias_nodes(node)
-
-    return AllWeightInfos(
-        weights=[_weight_node_to_info(wn, gm) for wn in weight_nodes],
-        biases=[_weight_node_to_info(bn, gm) for bn in bias_nodes],
-    )
-
-
 def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
     """Get a user from a node if the node matches a given op set and num of users."""
     if node is None:
@@ -703,7 +642,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[N
     """
     Get subgraphs for all consecutive layers (attention, MLP, SSM, MoE) in the graph.

-    Pre-computes weight mappings and caches weight shapes for all linear nodes.
+    Caches weight shapes for all linear nodes using WeightBiasInfoCache.

     Each layer is contained between opening linear layers and a single closing linear layer.
     Assumptions:
@@ -734,9 +673,6 @@ def get_all_layer_subgraphs(gm: GraphModule) -> tuple[List[LayerSubgraph], set[N
     layer_subgraphs = []
     linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))

-    # Pre-compute weight-to-consumer mapping for O(1) weight node lookup
-    precompute_weight_node_mapping(gm)
-
     # Cache weight shapes for all linear nodes
     for lin_node in linear_nodes:
         if "lin_node_shape" not in lin_node.meta:
@@ -1006,19 +942,25 @@ def subgraph(


 def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
-    """Get weight shape for a linear operation node. Returns None if no weight."""
+    """Get the shape of the weight node."""
     if not is_any_lin_op(node):
         return None

-    weight_node = get_weight_node(node)
-    if weight_node is None:
+    # Try to get from cache first
+    found, s = WeightBiasInfoCache.get_weight_shape(node)
+    if not found:
+        # Not in cache or caching not enabled - compute the shape
+        s = list(shape(extract_weight_nodes(node).weights[0].node))
+        if len(s) == 0:
+            s = None
+        elif is_fp4_op(node):
+            # FP4 weights are packed as uint8 type with 2 FP4 values per element
+            s[-1] *= 2
+        # Store in cache if caching is enabled
+        WeightBiasInfoCache.set_weight_shape(node, s)
+
+    if s is None:
         return None
-
-    s = list(shape(weight_node))
-
-    if is_fp4_op(node):
-        # FP4 weights are packed as uint8 type with 2 FP4 values per element
-        s[-1] *= 2
     if dim is None:
         return s
     else:
@@ -1254,13 +1196,11 @@ def shape(node: Node) -> Tuple[int, ...]:


 def get_weight_tensor(node: Node) -> torch.Tensor:
-    """Extract the weight tensor from a compute node."""
-    weight_node = get_weight_node(node)
-    if weight_node is None:
+    """Extract the weight tensor from a node within a GraphModule."""
+    weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
         raise ValueError(f"Node {node.name} has no weight")
-
-    gm = node.graph.owning_module
-    return get_param_or_buffer(weight_node.target, gm)
+    return weight_nodes.weights[0].tensor


 def draw_graph(gm: GraphModule, filename: str):
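
For reference, below is a minimal, self-contained sketch of the caching pattern this patch introduces with WeightBiasInfoCache: a class-level "active instance" scoped by a context manager, memoized named_parameters() lookups while the scope is open, and a plain fallback outside it. The class name _ScopedParamCache and the toy traced module are illustrative assumptions for the sketch, not code from this patch or the TensorRT-LLM API.

# Illustrative sketch only (not part of the patch): a stand-alone version of the
# context-manager-scoped cache that WeightBiasInfoCache follows.
from typing import Dict, Optional

import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace


class _ScopedParamCache:
    # Class-level reference to the currently active cache instance (None = no scope open).
    _active: Optional["_ScopedParamCache"] = None

    def __init__(self):
        self._param_dicts: Dict[GraphModule, dict] = {}
        _ScopedParamCache._active = self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Deactivate and drop cached entries when the scope closes.
        if _ScopedParamCache._active is self:
            _ScopedParamCache._active = None
        self._param_dicts.clear()
        return False

    @classmethod
    def get_param_dict(cls, gm: GraphModule) -> dict:
        if cls._active is None:
            # No cache scope open -> recompute on every call (previous behavior).
            return dict(gm.named_parameters())
        cache = cls._active._param_dicts
        if gm not in cache:
            # Walk the module tree once per GraphModule per scope.
            cache[gm] = dict(gm.named_parameters())
        return cache[gm]


if __name__ == "__main__":
    gm = symbolic_trace(nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2)))

    # Outside a scope, every lookup re-walks the module tree.
    assert "0.weight" in _ScopedParamCache.get_param_dict(gm)

    # Inside a scope, the dict is built once and reused; this is the reason the
    # transform passes above are wrapped in `with WeightBiasInfoCache():`.
    with _ScopedParamCache():
        for _ in range(3):
            _ScopedParamCache.get_param_dict(gm)["1.bias"]

The design choice mirrored here is that callers never receive the cache object itself; they keep calling the same classmethod lookups, and wrapping a whole transform pass in the context manager is enough to collapse repeated named_parameters()/named_buffers() walks into one per GraphModule.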