Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>

parent cfe53e7425
commit 76ec820465
@@ -45,6 +45,9 @@ transforms:
  cache_config:
    # mamba_dtype: float32
    mamba_dtype: null
  gather_logits_before_lm_head:
    # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
    enabled: true
  fuse_mamba_a_log:
    stage: post_load_fusion
    enabled: true

@@ -134,6 +134,10 @@ transforms:
  fuse_add_rms_norm:
    stage: post_load_fusion
    enabled: true
  gather_logits_before_lm_head:
    stage: post_load_fusion
    # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
    enabled: false
  ############################################################################################
  # VISUALIZE GRAPH
  ############################################################################################

@@ -367,6 +367,10 @@ class SequenceInfo:
          of decode sequences.
        - cache_loc: [c_0, c_1, ..., c_{np-1}] where np is total number of pages allocated to describe
          all sequences in the batch. Each value is a page index in the cache.
        - logits_gather_indices: [g_0, g_1, ..., g_{s_total-1}]
          Gather indices used by the gather_logits_before_lm_head custom op to gather logits before the LM head.
        - logits_gather_info: [num_tokens_to_gather, gather_required]. Info for the
          gather_logits_before_lm_head custom op to gather logits before the LM head.
        - _gather_idx: [g_0, g_1, ..., g_{s_total-1}]
          Gather indices used by the overlap scheduler to reorder input tokens.
        - _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}]
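For intuition, a hypothetical batch (values illustrative only): one context sequence of 4 tokens followed by two decode sequences of 1 token each gives s_total = 6. Gathering only the last context token plus the decode tokens would then look like:

    logits_gather_indices = [3, 4, 5]  # last token of the context sequence, then the two decode tokens
    logits_gather_info = [3, 1]        # gather 3 tokens; gather_required=1 because 3 < s_total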
@@ -480,6 +484,8 @@ class SequenceInfo:
            ("slot_idx", self.max_batch_size, torch.long),
            ("use_initial_states", self.max_batch_size, torch.bool),
            ("batch_info", 3, torch.int),
            ("logits_gather_indices", self.max_num_tokens, torch.long),
            ("logits_gather_info", 2, torch.int),
            # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
            ("_gather_idx", self.max_num_tokens, torch.int),
            ("_mask_scatter_indices", self.max_num_tokens, torch.int),
@@ -499,7 +505,7 @@ class SequenceInfo:
        self._active_args = ("input_ids", "position_ids")
        self._shapeable_args = ("input_ids", "position_ids")
        # Args that should be returned from host (pinned memory) instead of device in _named_args
        self._host_return_args = ("batch_info",)
        self._host_return_args = ("batch_info", "logits_gather_info")
        ############################################################################################

        # EXTRA TENSOR FIELDS ######################################################################
@@ -535,15 +541,17 @@ class SequenceInfo:
        # truncate to total tokens now, reshape, and return
        return tnsr[: self.total_num_tokens].view(bs, sl, *tnsr.shape[1:])

    def _get_arg(self, name: str) -> torch.Tensor:
        """Get the argument from the input buffer either on device or host."""
        if name in self._host_return_args:
            arg = self._input_buffer.get_host_view(name)
        else:
            arg = self._input_buffer.get_view(name)
        return self._shape_for_forward(arg) if name in self._shapeable_args else arg

    def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
        # Build args dict, using host views for _host_return_args, device views otherwise
        args = {}
        for name in self._active_args:
            if name in self._host_return_args:
                view = self._input_buffer.get_host_view(name)
            else:
                view = self._input_buffer.get_view(name)
            args[name] = self._shape_for_forward(view) if name in self._shapeable_args else view
        args = {k: self._get_arg(k) for k in self._active_args}

        # check other args to include
        if include_extra_args:
@@ -801,7 +809,11 @@ class SequenceInfo:

    def set_generate_only_batch(self, batch_size: Optional[int] = None) -> None:
        """Set an example sequence for generate-only batch."""
        self.set_example_sequence([[1]] * (batch_size or self.max_batch_size))
        batch_size = batch_size or self.max_batch_size
        self.set_example_sequence(
            [[1]] * batch_size,
            logits_gather_info=[batch_size, 0],
        )

    def reset(self) -> None:
        """Reset the sequence information.
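As a rough usage sketch (batch size assumed for illustration): a generate-only warm-up batch needs every token's logits, so the gather is marked as not required and becomes a no-op:

    seq_info.set_generate_only_batch(batch_size=4)
    # roughly equivalent to:
    # seq_info.set_example_sequence([[1]] * 4, logits_gather_info=[4, 0])  # gather_required=0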
@@ -887,6 +899,8 @@ class SequenceInfo:
        last_page_len: Optional[Sequence[int]] = None,
        slot_idx: Optional[Sequence[int]] = None,
        use_initial_states: Optional[Sequence[bool]] = None,
        logits_gather_indices: Optional[Sequence[int]] = None,
        logits_gather_info: Optional[Sequence[int]] = None,
        _gather_idx: Optional[Sequence[int]] = None,
        _mask_scatter_indices: Optional[Sequence[int]] = None,
        **extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
@@ -917,6 +931,8 @@ class SequenceInfo:
            slot_idx: Slot index for each sequence in the batch.
            use_initial_states: Per-sequence boolean indicating if the initial states should be
                used. If None, auto-computed as (input_pos > 0).
            logits_gather_indices: Gather indices for the logits before/after the LM head.
            logits_gather_info: Info list containing [num_tokens_to_gather, gather_required].
            _gather_idx: Gather indices for the overlap scheduler to reorder input tokens.
            _mask_scatter_indices: Mask scatter indices for the overlap scheduler.
            extra_args: Extra arguments to be stored in the interface.
@@ -999,6 +1015,17 @@ class SequenceInfo:
            use_initial_states = [i_p > 0 for i_p in self.input_pos]
        self._store_arg("use_initial_states", use_initial_states)

        # check for updated logits_gather_indices
        if logits_gather_indices is None:
            # default is to gather all logits
            logits_gather_indices = list(range(self.total_num_tokens))
        self._store_arg("logits_gather_indices", logits_gather_indices, force_copy=True)

        # check for updated logits_gather_info
        if logits_gather_info is None:
            logits_gather_info = [len(logits_gather_indices), 1]
        self._store_arg("logits_gather_info", logits_gather_info, force_copy=True)

        ### UPDATE OVERLAP SCHEDULER METADATA ######################################################
        # check for updated _gather_idx
        if _gather_idx is not None:
@@ -1042,9 +1069,24 @@ class SequenceInfo:
            ungathered_input_ids, gather_ids_device, mask_scatter_indices_device, input_ids_device
        )

    # TODO: remove once https://github.com/NVIDIA/TensorRT-LLM/issues/9878 is fixed and
    # logits gather is enabled by default (only keep squeeze_logits)
    @nvtx_range("ad_maybe_gather_logits")
    def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Maybe gather the logits if logits have not been gathered yet."""
        num_tokens = logits.shape[0] * logits.shape[1]
        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
        if gather_required and num_tokens_to_gather < num_tokens:
            logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
                logits,
                self._get_arg("logits_gather_indices"),
                self._get_arg("logits_gather_info"),
            )
        return logits.squeeze(int(self.is_generate))

    @nvtx_range("ad_unnest_sequences")
    def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
        t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
        t_squeezed = t_nested.squeeze(int(self.is_generate))
        return list(torch.split(t_squeezed, self.seq_len))

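A rough shape walk-through of maybe_gather_and_squeeze_logits, assuming the layouts described in the docstrings above:

    # decode-only (is_generate=True):   logits [b, 1, vocab]; num_tokens == num_tokens_to_gather
    #   -> no gather, squeeze(1) -> [b, vocab]
    # mixed/prefill (is_generate=False): logits [1, s_total, vocab]; gather_required=1, n < s_total
    #   -> gather -> [1, n, vocab], squeeze(0) -> [n, vocab]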
@@ -0,0 +1,41 @@
import torch


@torch.library.custom_op("auto_deploy::gather_logits_before_lm_head", mutates_args=())
def gather_logits_before_lm_head(
    hidden_states: torch.Tensor,
    logits_gather_indices: torch.Tensor,  # long tensor
    logits_gather_info: torch.Tensor,  # int tensor
) -> torch.Tensor:
    """Gather hidden states using logits_gather_indices before LM head.

    Args:
        hidden_states: Hidden states tensor [b, 1, hidden] or [1, s_total, hidden]
        logits_gather_indices: indices for gathering logits.
        logits_gather_info: info for gathering logits.
    Returns:
        Gathered and flattened hidden states [num_gathered_tokens, hidden]
    """
    # final shape is [total_tokens, hidden] from [b, 1, hidden] or [1, total_tokens, hidden]
    is_decode_only = hidden_states.shape[1] == 1
    hidden_states = hidden_states.squeeze(int(is_decode_only))

    # info object
    num_tokens_to_gather, gather_required = logits_gather_info.tolist()

    if gather_required:
        out = hidden_states.index_select(0, logits_gather_indices[:num_tokens_to_gather])
    else:
        out = hidden_states.clone(memory_format=torch.contiguous_format)
    return out.unsqueeze(int(is_decode_only))


@gather_logits_before_lm_head.register_fake
def gather_logits_before_lm_head_fake(
    hidden_states: torch.Tensor,
    logits_gather_indices: torch.Tensor,
    logits_gather_info: torch.Tensor,
) -> torch.Tensor:
    # NOTE: shape is not correct in fake mode
    # see https://github.com/NVIDIA/TensorRT-LLM/issues/9878
    return torch.empty_like(hidden_states)
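A minimal usage sketch of the op above (tensor values are illustrative): keep the last token of a 4-token context sequence plus two decode tokens from a packed batch of 6 tokens:

    hidden = torch.randn(1, 6, 128, device="cuda", dtype=torch.float16)
    idx = torch.tensor([3, 4, 5], dtype=torch.long, device="cuda")
    info = torch.tensor([3, 1], dtype=torch.int32, device="cuda")  # [num_tokens_to_gather, gather_required]
    out = torch.ops.auto_deploy.gather_logits_before_lm_head(hidden, idx, info)
    # out.shape == (1, 3, 128): packed 3D layout preserved, only the selected tokens remain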
@@ -10,6 +10,7 @@
# limitations under the License.

import copy
import types
from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
@@ -319,7 +320,7 @@ class ADEngine(ModelEngine):
        return self.cache_seq_interface.device

    @classmethod
    def build_from_config(cls, ad_config: LlmArgs):
    def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None):
        """Build the ADEngine using the LlmArgs that gets passed through from the LLM."""

        max_batch_size = ad_config.max_batch_size
@@ -360,6 +361,7 @@ class ADEngine(ModelEngine):
            seq_info,
            device,
            ad_config=ad_config,
            mapping=mapping,
            reporting_info=reporting_info,
        )

@@ -370,6 +372,7 @@ class ADEngine(ModelEngine):
        seq_info: SequenceInfo,
        device: DeviceLikeType,
        ad_config: Optional[LlmArgs] = None,
        mapping: Optional[Mapping] = None,
        reporting_info: ReportingInfo = ReportingInfo(),
    ) -> None:
        """Initialize the engine with model and sequence information."""
@@ -439,13 +442,20 @@ class ADEngine(ModelEngine):
        # keep a reference for one dummy request around
        self.padding_dummy_request: Optional[LlmRequest] = None

        # Reuse _execute_logit_post_processors from PyTorchModelEngine
        self.mapping = mapping
        self._execute_logit_post_processors = types.MethodType(
            PyTorchModelEngine._execute_logit_post_processors, self
        )

    @nvtx_range("ad_prepare_inputs")
    def _prepare_inputs(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: ResourceManager,
        new_tokens: Optional[torch.Tensor] = None,
    ) -> List[bool]:
        gather_context_logits: bool = False,
    ) -> None:
        """Prepare inputs for AD Model from scheduled requests."""
        # cache manager
        kv_cache_manager = resource_manager.get_resource_manager(
@@ -467,7 +477,6 @@ class ADEngine(ModelEngine):
        input_pos: List[int] = []
        seq_len: List[int] = []
        cu_seqlen: List[int] = [0]
        last_logit_only: List[bool] = []
        cache_loc: List[int] = []
        pages_per_seq: List[int] = []
        cu_num_pages: List[int] = [0]
@@ -481,6 +490,9 @@ class ADEngine(ModelEngine):
        mask_scatter_indices: List[int] = []
        extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)

        # gather indices for logits
        logits_gather_indices: List[int] = []

        page_size = self.cache_seq_interface.info.page_size
        dummy_token = -1
        num_ctx_requests = len(context_requests)
@@ -505,7 +517,11 @@ class ADEngine(ModelEngine):
            cu_seqlen.append(cu_seqlen[-1] + seq_len[-1])

            request.py_batch_idx = request.seq_slot
            last_logit_only.append(True)

            if gather_context_logits:
                logits_gather_indices.extend(range(cu_seqlen[-2], cu_seqlen[-1]))
            else:
                logits_gather_indices.append(cu_seqlen[-1] - 1)

            # get cache indices and truncate the number of blocks according to end_compute
            cache_indices = kv_cache_manager.get_cache_indices(request)
@@ -600,11 +616,13 @@ class ADEngine(ModelEngine):
            request.py_batch_idx = request.seq_slot
            slot_idx.append(request.seq_slot)
            use_initial_states.append(input_pos[-1] > 0)
            last_logit_only.append(False)

            seq_len.append(len(input_ids[-1]))
            cu_seqlen.append(cu_seqlen[-1] + seq_len[-1])

            # for generate requests, we always keep all logits (target logits + draft logits)
            logits_gather_indices.extend(range(cu_seqlen[-2], cu_seqlen[-1]))

            if use_overlap:
                mask_scatter_indices.extend(list(range(cu_seqlen[-2], cu_seqlen[-1])))

@@ -618,6 +636,14 @@ class ADEngine(ModelEngine):

            position_ids.append(list(range(input_pos[-1], seq_len_with_cache[-1])))

        # check for logits_gather_info
        # we only need to gather in the following situation:
        # 1. there are context requests and
        # 2. we are not gathering context logits
        # In other cases (decode-only) or when we keep all logits, we do not need to gather.
        gather_required = len(context_requests) > 0 and not gather_context_logits
        logits_gather_info = [len(logits_gather_indices), int(gather_required)]

        # update the sequence info object now
        self.cache_seq_interface.info.nest_sequences(
            input_ids,
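A worked example of the bookkeeping above (request lengths are hypothetical): two context requests with 3 and 2 tokens plus one generate request, with gather_context_logits=False:

    # cu_seqlen             = [0, 3, 5, 6]
    # logits_gather_indices = [2, 4, 5]   # last token of each context request + the generate token
    # gather_required       = True        # context requests present and context logits not requested
    # logits_gather_info    = [3, 1]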
@@ -632,6 +658,8 @@ class ADEngine(ModelEngine):
            last_page_len=last_page_len,
            slot_idx=slot_idx,
            use_initial_states=use_initial_states,
            logits_gather_indices=logits_gather_indices,
            logits_gather_info=logits_gather_info,
            _gather_idx=None if new_tokens is None else flat_gather_indices,
            _mask_scatter_indices=None if new_tokens is None else mask_scatter_indices,
            **extra_args,
@@ -644,18 +672,15 @@ class ADEngine(ModelEngine):
        self.iter_states["num_ctx_tokens"] = num_ctx_tokens
        # TODO: handle extend requests and draft requests for specdec
        self.iter_states["num_generation_tokens"] = num_generation_tokens
        return last_logit_only

    @nvtx_range("ad_compute_logits")
    def _compute_logits(self) -> List[torch.Tensor]:
        # run the model
        logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
        logits = self.cache_seq_interface.info.maybe_gather_and_squeeze_logits(logits)

        # TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
        logits = logits.float()

        # return a list of tensors
        return self.cache_seq_interface.info.unnest_sequences(logits)
        return logits.float()

    def get_max_num_sequences(self) -> int:
        """Maximum number of sequences supported by the engine."""
@@ -674,19 +699,18 @@ class ADEngine(ModelEngine):
        """Run forward from scheduled requests; main entrypoint that gets called by the executor."""
        # convert requests and store in sequence info object
        new_tokens = getattr(new_tensors_device, "new_tokens", None)
        last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)
        self._prepare_inputs(
            scheduled_requests, resource_manager, new_tokens, gather_context_logits
        )
        self.iter_counter += 1

        # compute all logits
        logits = self._compute_logits()
        outputs = {
            "logits": self._compute_logits(),
        }
        if self.mapping is not None:
            self._execute_logit_post_processors(scheduled_requests, outputs)

        # gather+cat logits
        logits_flat = torch.cat(
            [ls_one_seq[-last_only:] for ls_one_seq, last_only in zip(logits, last_logit_only)],
            dim=0,
        )

        return {"logits": logits_flat}
        return outputs


def create_draft_model_engine_maybe(
@@ -825,7 +849,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
    )

    # initialize model engine
    engine = ADEngine.build_from_config(ad_config=ad_config)
    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)

    spec_config = ad_config.speculative_config
    if spec_config is not None and not spec_config.spec_dec_mode.is_draft_target():
@@ -137,7 +137,7 @@ class DemoEngine(ADEngine):
        context_logits: Optional[List[torch.Tensor]] = None

        def _generate_single_step(idx: int):
            logits = self._compute_logits()
            logits = sequence_info.unnest_sequences(self._compute_logits())
            logits_last = torch.stack([l_one_seq[-1] for l_one_seq in logits]).float().unsqueeze(1)

            token_ids, _ = self._decode_tokens(logits_last, sampling_params)  # [b,1]

@@ -12,11 +12,12 @@ from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final

import torch.nn as nn
from pydantic import BaseModel, Field
from torch.fx import GraphModule
from torch.fx import GraphModule, Node

from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..utils._graph import (
    add_graph_input,
    canonicalize_graph,
    lift_to_meta,
    named_graphmodules,
@@ -505,6 +506,19 @@ class BaseTransform(ABC):
                f"Transform {self.get_transform_key()} only supports `run_per_gm=True`."
            )

    def _add_or_retrieve_input(
        self, gm: GraphModule, cm: CachedSequenceInterface, name: str
    ) -> Node:
        """Add or retrieve an input node from the graph."""
        input_nodes = gm.graph.find_nodes(op="placeholder", target=name)
        if len(input_nodes) == 0:
            cm.info.activate_arg(name)
            return add_graph_input(gm, name)
        elif len(input_nodes) == 1:
            return input_nodes[0]
        else:
            raise ValueError(f"Expected exactly one input node for {name=}, got {input_nodes=}")


class TransformRegistry:
    """A registry for all transforms."""

@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transform to gather hidden states before LM head using logits_gather_mask.

This moves the gather operation into the model graph before the LM head,
enabling CUDA graph capture and reducing computation by only computing
logits for the tokens that are actually needed.
"""

from typing import Tuple

import torch
from torch.fx import GraphModule

from ...utils.node_utils import is_linear_op, is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


@TransformRegistry.register("gather_logits_before_lm_head")
class GatherLogitsBeforeLmHeadTransform(BaseTransform):
    """Transform to gather hidden states before LM head using logits_gather_mask.

    This transform inserts a gather operation before the LM head linear layer
    to select only the hidden states that need logits computed. The output is always
    [b, hidden_size] in decode-only for CUDA graph compatibility.

    Benefits:
    - Reduces computation by only computing logits for needed tokens
    - Eliminates Python loop overhead
    - Enables CUDA graph capture of the gather
    - Moves gather into the graph for better optimization
    """

    def _apply(
        self,
        gm: GraphModule,
        cm,
        factory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        self._log_info("Applying GatherLogitsBeforeLmHead transform...")

        # assume lm head node is the input to the output node
        lm_head_node = gm.graph.find_nodes(op="output")[0].all_input_nodes[0]
        if is_op(lm_head_node, torch.ops.aten.to):
            lm_head_node = lm_head_node.all_input_nodes[0]

        if is_linear_op(lm_head_node):
            node_to_gather = lm_head_node.all_input_nodes[0]
            self._log_info(f"Found LM head node: {lm_head_node.name}")
        else:
            node_to_gather = lm_head_node
            self._log_info("lm_head node is not linear, using it as the node to gather")

        # Add logits_gather_mask as input in the graph and the sequence info interface
        logits_gather_indices_node = self._add_or_retrieve_input(gm, cm, "logits_gather_indices")
        logits_gather_info_node = self._add_or_retrieve_input(gm, cm, "logits_gather_info")

        with gm.graph.inserting_after(node_to_gather):
            gathered_node = gm.graph.call_function(
                torch.ops.auto_deploy.gather_logits_before_lm_head.default,
                args=(node_to_gather, logits_gather_indices_node, logits_gather_info_node),
            )
        node_to_gather.replace_all_uses_with(gathered_node)
        gathered_node.replace_input_with(gathered_node, node_to_gather)

        return gm, TransformInfo(skipped=False, num_matches=1)
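The replace_all_uses_with / replace_input_with pair above is the standard torch.fx idiom for splicing a new node between a producer and all of its consumers; a self-contained sketch with a toy module and a stand-in op (nothing here is from this commit):

    import torch
    from torch.fx import symbolic_trace

    def gather_stub(x):  # stand-in for the gather custom op
        return x

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1

    gm = symbolic_trace(Toy())
    producer = next(n for n in gm.graph.nodes if n.op == "placeholder")
    with gm.graph.inserting_after(producer):
        new_node = gm.graph.call_function(gather_stub, args=(producer,))
    producer.replace_all_uses_with(new_node)          # consumers (and new_node itself) now point at new_node
    new_node.replace_input_with(new_node, producer)   # undo the self-reference so new_node reads the producer
    gm.recompile()

This is why the transform calls gathered_node.replace_input_with(gathered_node, node_to_gather) immediately after replace_all_uses_with.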
@@ -54,19 +54,6 @@ class InsertCachedAttention(BaseTransform):
    def attn_descriptor(self) -> Type[AttentionDescriptor]:
        return AttentionRegistry.get(self.config.backend)

    def _add_or_retrieve_input(
        self, gm: GraphModule, cm: CachedSequenceInterface, name: str
    ) -> Node:
        """Add or retrieve an input node from the graph."""
        input_nodes = gm.graph.find_nodes(op="placeholder", target=name)
        if len(input_nodes) == 0:
            cm.info.activate_arg(name)
            return add_graph_input(gm, name)
        elif len(input_nodes) == 1:
            return input_nodes[0]
        else:
            raise ValueError(f"Expected exactly one input node for {name=}, got {input_nodes=}")

    def _process_metadata_std(self, gm: GraphModule, cm: CachedSequenceInterface) -> List[Node]:
        """Process the standard metadata nodes."""
        return [

@@ -71,7 +71,6 @@ def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
    sequence_info.reset()
    sequence_info.nest_sequences(input_ids)
    logits = engine._compute_logits()
    logits = torch.stack(logits)
    assert logits is not None, "Logits are None"

    mock_input = None
@@ -105,7 +104,6 @@ def test_demo_engine_sampling(attn_page_size: int):
    sequence_info.reset()
    sequence_info.nest_sequences(input_ids)
    logits = engine._compute_logits()
    logits = torch.stack(logits)

    vocab_size = logits.size(-1)
    sampling_params = SamplingParams(top_k=5, temperature=1.0)
@@ -199,7 +197,7 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
    # No-chunk: whole prompt in one request
    req_full = _DummyRequest(tokens=tokens, begin=0, size=len(tokens), seq_slot=0)
    scheduled_full = SimpleNamespace(context_requests=[req_full], generation_requests=[])
    logits_full = engine.forward(scheduled_full, resource_manager)["logits"]
    logits_full_last = engine.forward(scheduled_full, resource_manager)["logits"][-1]

    # Chunked: split into two context chunks
    split = len(tokens) // 2
@@ -211,7 +209,6 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):

    # Run first chunk (ignored output), then compare second chunk logits to full
    _ = engine.forward(scheduled_part1, resource_manager)
    logits_chunked_last = engine.forward(scheduled_part2, resource_manager)["logits"]
    logits_chunked_last = engine.forward(scheduled_part2, resource_manager)["logits"][-1]

    assert logits_full.shape == logits_chunked_last.shape
    assert torch.allclose(logits_full, logits_chunked_last, atol=1e-5)
    torch.testing.assert_close(logits_full_last, logits_chunked_last)  # , atol=1e-5)

@@ -0,0 +1,326 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for gather_logits_before_lm_head transform."""

import pytest
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Dim
from torch.fx import GraphModule

# Import to register custom op
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


class SimpleLMHeadModel(torch.nn.Module):
    """Simple model with LM head for testing."""

    def __init__(self, hidden_size: int = 128, vocab_size: int = 1000):
        super().__init__()
        self.linear1 = torch.nn.Linear(hidden_size, hidden_size, device="cuda", dtype=torch.float16)
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, device="cuda", dtype=torch.float16)

    def forward(self, hidden_states, logit_gather_ids=None, seq_len=None):
        # Simulate transformer output
        hidden_states = self.linear1(hidden_states)
        # LM head
        logits = self.lm_head(hidden_states)
        return logits


class TestGatherLogitsBeforeLmHeadOp:
    """Test the custom op directly."""

    @pytest.mark.parametrize("batch_size", [1, 4, 8])
    def test_generate_format(self, batch_size):
        """Test gather op with generate format input [batch, 1, hidden]."""
        hidden_size = 128
        hidden_states = torch.randn(batch_size, 1, hidden_size, device="cuda", dtype=torch.float16)

        # Create gather info: num_tokens_to_gather=batch_size, gather_required=0 (False)
        logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
        logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")

        output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
            hidden_states, logits_gather_indices, logits_gather_info
        )

        # Should return [batch, 1, hidden] for generate format (3D shape preserved)
        assert output.shape == (batch_size, 1, hidden_size)
        assert output.dtype == hidden_states.dtype
        assert output.device == hidden_states.device
        torch.testing.assert_close(output, hidden_states)

    @pytest.mark.parametrize("total_tokens", [10, 50, 100])
    def test_packed_format(self, total_tokens):
        """Test gather op with packed format input [1, total_tokens, hidden]."""
        hidden_size = 128
        max_batch_size = 8
        hidden_states = torch.randn(
            1, total_tokens, hidden_size, device="cuda", dtype=torch.float16
        )

        # Create gather indices: gather tokens at indices [0, 5, 10, ...] up to max_batch_size
        num_gather = min(total_tokens, max_batch_size)
        gather_indices = torch.arange(0, num_gather, dtype=torch.long, device="cuda")

        # Create gather info: num_tokens_to_gather=num_gather, gather_required=1 (True)
        logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")

        output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
            hidden_states, gather_indices, logits_gather_info
        )

        # Should return [1, num_gather, hidden] for packed format (3D shape preserved)
        assert output.shape == (1, num_gather, hidden_size)
        assert output.dtype == hidden_states.dtype
        assert output.device == hidden_states.device

        # Verify gathered values match expected indices
        expected = hidden_states[:, gather_indices, :]
        torch.testing.assert_close(output, expected)

    def test_fake_implementation_generate_format(self):
        """Test fake implementation for generate format."""
        batch_size = 4
        hidden_size = 128
        hidden_states = torch.randn(batch_size, 1, hidden_size, device="cuda", dtype=torch.float16)

        # Create gather info
        logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
        logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")

        # Use fake implementation directly
        with FakeTensorMode() as mode:
            hidden_states_fake = mode.from_tensor(hidden_states)
            logits_gather_indices_fake = mode.from_tensor(logits_gather_indices)
            logits_gather_info_fake = mode.from_tensor(logits_gather_info)
            output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
                hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake
            )

        # Should return [batch, 1, hidden_size] (fake returns empty_like which preserves 3D shape)
        assert output.shape == (batch_size, 1, hidden_size)
        assert output.dtype == hidden_states.dtype
        assert output.device == hidden_states.device

    def test_fake_implementation_packed_format(self):
        """Test fake implementation for packed format."""
        total_tokens = 50
        hidden_size = 128
        num_gather = 8
        hidden_states = torch.randn(
            1, total_tokens, hidden_size, device="cuda", dtype=torch.float16
        )

        # Create gather info
        logits_gather_indices = torch.arange(num_gather, dtype=torch.long, device="cuda")
        logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")

        # Use fake implementation directly
        with FakeTensorMode() as mode:
            hidden_states_fake = mode.from_tensor(hidden_states)
            logits_gather_indices_fake = mode.from_tensor(logits_gather_indices)
            logits_gather_info_fake = mode.from_tensor(logits_gather_info)
            output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
                hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake
            )

        # The fake implementation returns empty_like which preserves input shape [1, total_tokens, hidden]
        assert output.shape == (1, total_tokens, hidden_size)
        assert output.dtype == hidden_states.dtype
        assert output.device == hidden_states.device


class TestGatherLogitsBeforeLmHeadTransform:
    """Test the transform application."""

    def _create_cached_sequence_interface(self, max_batch_size: int = 8, device: str = "cuda"):
        """Create a mock CachedSequenceInterface for testing."""
        seq_info = SequenceInfo(
            max_seq_len=64,
            max_batch_size=max_batch_size,
            max_num_tokens=1024,
        )
        seq_info.to(device)
        return CachedSequenceInterface(seq_info, device=device)

    def _check_gather_op_in_graph(self, gm: GraphModule) -> bool:
        """Check if gather_logits_before_lm_head op is in the graph."""
        return any(
            is_op(n, torch.ops.auto_deploy.gather_logits_before_lm_head) for n in gm.graph.nodes
        )

    @pytest.mark.parametrize("batch_size", [1, 4, 8])
    def test_transform_generate_format(self, batch_size):
        """Test transform with generate format input."""
        hidden_size = 128
        vocab_size = 1000
        model = SimpleLMHeadModel(hidden_size, vocab_size).cuda()

        # Create input in generate format [batch, 1, hidden]
        hidden_states = torch.randn(batch_size, 1, hidden_size, device="cuda", dtype=torch.float16)
        max_batch_size = 8
        logit_gather_ids = torch.zeros(max_batch_size, dtype=torch.long, device="cuda")
        seq_len = torch.ones(batch_size, dtype=torch.long, device="cuda")

        # Export model
        # When batch_size=1, torch.export specializes it to a constant, so we skip dynamic shapes
        # For batch_size > 1, we use dynamic shapes to test the transform with varying batch sizes
        if batch_size == 1:
            dynamic_shapes = None
        else:
            # dynamic_shapes should be a tuple matching the number of positional args
            dynamic_shapes = (
                {0: Dim("batch_size", min=1, max=max_batch_size)},  # hidden_states
                None,  # logit_gather_ids (static)
                None,  # seq_len (static)
            )
        gm = torch_export_to_gm(
            model,
            args=(hidden_states, logit_gather_ids, seq_len),
            dynamic_shapes=dynamic_shapes,
            clone=True,
        )

        # Apply transform
        cm = self._create_cached_sequence_interface(max_batch_size)
        transform_config = {
            "gather_logits_before_lm_head": {
                "stage": "post_load_fusion",
                "max_batch_size": max_batch_size,
            }
        }
        optimizer = InferenceOptimizer(None, transform_config)
        gm_transformed = optimizer(cm, gm)

        # Check that gather op was inserted
        assert self._check_gather_op_in_graph(gm_transformed), "Gather op not found in graph"

        # Test forward pass
        # We must pass the new graph inputs manually since we are running the graph directly
        logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
        logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
        output = gm_transformed(
            hidden_states,
            logit_gather_ids,
            seq_len,
            logits_gather_indices=logits_gather_indices,
            logits_gather_info=logits_gather_info,
        )
        # Output should be [batch_size, 1, vocab_size] since gather now returns 3D
        assert output.shape == (batch_size, 1, vocab_size)

    @pytest.mark.parametrize("total_tokens", [10, 50])
    def test_transform_packed_format(self, total_tokens):
        """Test transform with packed format input."""
        hidden_size = 128
        vocab_size = 1000
        max_batch_size = 8
        model = SimpleLMHeadModel(hidden_size, vocab_size).cuda()

        # Create input in packed format [1, total_tokens, hidden]
        hidden_states = torch.randn(
            1, total_tokens, hidden_size, device="cuda", dtype=torch.float16
        )
        logit_gather_ids = torch.arange(
            0, min(total_tokens, max_batch_size), dtype=torch.long, device="cuda"
        )
        # Pad to max_batch_size
        logit_gather_ids_padded = torch.zeros(max_batch_size, dtype=torch.long, device="cuda")
        logit_gather_ids_padded[: len(logit_gather_ids)] = logit_gather_ids

        seq_len = torch.ones(max_batch_size, dtype=torch.long, device="cuda")
        seq_len[: len(logit_gather_ids)] = torch.ones(
            len(logit_gather_ids), dtype=torch.long, device="cuda"
        )

        # Export model
        gm = torch_export_to_gm(
            model,
            args=(hidden_states, logit_gather_ids_padded, seq_len),
            dynamic_shapes=None,
            clone=True,
        )

        # Apply transform
        cm = self._create_cached_sequence_interface(max_batch_size)
        transform_config = {
            "gather_logits_before_lm_head": {
                "stage": "post_load_fusion",
                "max_batch_size": max_batch_size,
            }
        }
        optimizer = InferenceOptimizer(None, transform_config)
        gm_transformed = optimizer(cm, gm)

        # Check that gather op was inserted
        assert self._check_gather_op_in_graph(gm_transformed), "Gather op not found in graph"

        # Test forward pass
        # We must pass the new graph inputs manually since we are running the graph directly
        num_gather = len(logit_gather_ids)
        logits_gather_indices = logit_gather_ids
        logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
        output = gm_transformed(
            hidden_states,
            logit_gather_ids_padded,
            seq_len,
            logits_gather_indices=logits_gather_indices,
            logits_gather_info=logits_gather_info,
        )
        # Output should be [1, num_gather, vocab_size] since gather now returns 3D
        assert output.shape == (1, num_gather, vocab_size)

    def test_transform_skips_when_disabled(self):
        """Test that transform skips when disabled."""
        hidden_size = 128
        vocab_size = 1000
        model = SimpleLMHeadModel(hidden_size, vocab_size).cuda()

        hidden_states = torch.randn(4, 1, hidden_size, device="cuda", dtype=torch.float16)
        max_batch_size = 8
        logit_gather_ids = torch.zeros(max_batch_size, dtype=torch.long, device="cuda")
        seq_len = torch.ones(4, dtype=torch.long, device="cuda")

        # Export model
        gm = torch_export_to_gm(
            model,
            args=(hidden_states, logit_gather_ids, seq_len),
            dynamic_shapes=None,
            clone=True,
        )

        # Apply transform with disabled config
        cm = self._create_cached_sequence_interface(max_batch_size)
        transform_config = {
            "gather_logits_before_lm_head": {
                "stage": "post_load_fusion",
                "enabled": False,
                "max_batch_size": max_batch_size,
            }
        }
        optimizer = InferenceOptimizer(None, transform_config)
        gm_transformed = optimizer(cm, gm)

        # Check that gather op was NOT inserted
        assert not self._check_gather_op_in_graph(gm_transformed), (
            "Gather op should not be in graph"
        )