[#7532][feat] AutoDeploy: gather logits before lm head (#9962)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Lucas Liebenwein 2025-12-17 22:50:13 -05:00 committed by GitHub
parent cfe53e7425
commit 76ec820465
11 changed files with 570 additions and 52 deletions

View File

@ -45,6 +45,9 @@ transforms:
cache_config:
# mamba_dtype: float32
mamba_dtype: null
gather_logits_before_lm_head:
# TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
enabled: true
fuse_mamba_a_log:
stage: post_load_fusion
enabled: true
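
For orientation, the gather_logits_before_lm_head entry above can also be supplied programmatically as a transform-config dict, which is the form the unit tests further down use. A minimal sketch, with the optimizer and graph setup assumed:

# Illustrative only: Python-side equivalent of the YAML entry above.
transform_config = {
    "gather_logits_before_lm_head": {
        "stage": "post_load_fusion",
        "enabled": True,  # the default config ships with enabled: false until issue #9878 is fixed
    }
}
# optimizer = InferenceOptimizer(None, transform_config)  # factory/graph setup omitted in this sketch
# gm = optimizer(cm, gm)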

View File

@ -134,6 +134,10 @@ transforms:
fuse_add_rms_norm:
stage: post_load_fusion
enabled: true
gather_logits_before_lm_head:
stage: post_load_fusion
# TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
enabled: false
############################################################################################
# VISUALIZE GRAPH
############################################################################################

View File

@ -367,6 +367,10 @@ class SequenceInfo:
of decode sequences.
- cache_loc: [c_0, c_1, ..., c_{np-1}] where np is total number of pages allocated to describe
all sequences in the batch. Each value is a page index in the cache.
- logits_gather_indices: [g_0, g_1, ..., g_{s_total-1}]
Gather indices used by the gather_logits_before_lm_head custom op to select the hidden states whose logits are needed before the LM head.
- logits_gather_info: [num_tokens_to_gather, gather_required]. Metadata for the
gather_logits_before_lm_head custom op: how many of the indices are valid and whether a gather is required at all.
- _gather_idx: [g_0, g_1, ..., g_{s_total-1}]
Gather indices used by the overlap scheduler to reorder input tokens.
- _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}]
@ -480,6 +484,8 @@ class SequenceInfo:
("slot_idx", self.max_batch_size, torch.long),
("use_initial_states", self.max_batch_size, torch.bool),
("batch_info", 3, torch.int),
("logits_gather_indices", self.max_num_tokens, torch.long),
("logits_gather_info", 2, torch.int),
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
("_gather_idx", self.max_num_tokens, torch.int),
("_mask_scatter_indices", self.max_num_tokens, torch.int),
@ -499,7 +505,7 @@ class SequenceInfo:
self._active_args = ("input_ids", "position_ids")
self._shapeable_args = ("input_ids", "position_ids")
# Args that should be returned from host (pinned memory) instead of device in _named_args
self._host_return_args = ("batch_info",)
self._host_return_args = ("batch_info", "logits_gather_info")
############################################################################################
# EXTRA TENSOR FIELDS ######################################################################
@ -535,15 +541,17 @@ class SequenceInfo:
# truncate to total tokens now, reshape, and return
return tnsr[: self.total_num_tokens].view(bs, sl, *tnsr.shape[1:])
def _get_arg(self, name: str) -> torch.Tensor:
"""Get the argument from the input buffer either on device or host."""
if name in self._host_return_args:
arg = self._input_buffer.get_host_view(name)
else:
arg = self._input_buffer.get_view(name)
return self._shape_for_forward(arg) if name in self._shapeable_args else arg
def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
# Build args dict, using host views for _host_return_args, device views otherwise
args = {}
for name in self._active_args:
if name in self._host_return_args:
view = self._input_buffer.get_host_view(name)
else:
view = self._input_buffer.get_view(name)
args[name] = self._shape_for_forward(view) if name in self._shapeable_args else view
args = {k: self._get_arg(k) for k in self._active_args}
# check other args to include
if include_extra_args:
@ -801,7 +809,11 @@ class SequenceInfo:
def set_generate_only_batch(self, batch_size: Optional[int] = None) -> None:
"""Set an example sequence for generate-only batch."""
self.set_example_sequence([[1]] * (batch_size or self.max_batch_size))
batch_size = batch_size or self.max_batch_size
self.set_example_sequence(
[[1]] * batch_size,
logits_gather_info=[batch_size, 0],
)
def reset(self) -> None:
"""Reset the sequence information.
@ -887,6 +899,8 @@ class SequenceInfo:
last_page_len: Optional[Sequence[int]] = None,
slot_idx: Optional[Sequence[int]] = None,
use_initial_states: Optional[Sequence[bool]] = None,
logits_gather_indices: Optional[Sequence[int]] = None,
logits_gather_info: Optional[Sequence[int]] = None,
_gather_idx: Optional[Sequence[int]] = None,
_mask_scatter_indices: Optional[Sequence[int]] = None,
**extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
@ -917,6 +931,8 @@ class SequenceInfo:
slot_idx: Slot index for each sequence in the batch.
use_initial_states: Per-sequence boolean indicating if the initial states should be
used. If None, auto-computed as (input_pos > 0).
logits_gather_indices: Indices of the tokens whose logits should be kept; applied to the
hidden states before the LM head (in-graph) or to the logits afterwards as a fallback.
logits_gather_info: Info list containing [num_tokens_to_gather, gather_required].
_gather_idx: Gather indices for the overlap scheduler to reorder input tokens.
_mask_scatter_indices: Mask scatter indices for the overlap scheduler.
extra_args: Extra arguments to be stored in the interface.
@ -999,6 +1015,17 @@ class SequenceInfo:
use_initial_states = [i_p > 0 for i_p in self.input_pos]
self._store_arg("use_initial_states", use_initial_states)
# check for updated logits_gather_indices
if logits_gather_indices is None:
# default is to gather all logits
logits_gather_indices = list(range(self.total_num_tokens))
self._store_arg("logits_gather_indices", logits_gather_indices, force_copy=True)
# check for updated logits_gather_info
if logits_gather_info is None:
logits_gather_info = [len(logits_gather_indices), 1]
self._store_arg("logits_gather_info", logits_gather_info, force_copy=True)
### UPDATE OVERLAP SCHEDULER METADATA ######################################################
# check for updated _gather_idx
if _gather_idx is not None:
@ -1042,9 +1069,24 @@ class SequenceInfo:
ungathered_input_ids, gather_ids_device, mask_scatter_indices_device, input_ids_device
)
# TODO: remove once https://github.com/NVIDIA/TensorRT-LLM/issues/9878 is fixed and
# logits gather is enabled by default (only keep squeeze_logits)
@nvtx_range("ad_maybe_gather_logits")
def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""Maybe gather the logits if logits have not been gathered yet."""
num_tokens = logits.shape[0] * logits.shape[1]
num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
if gather_required and num_tokens_to_gather < num_tokens:
logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
logits,
self._get_arg("logits_gather_indices"),
self._get_arg("logits_gather_info"),
)
return logits.squeeze(int(self.is_generate))
@nvtx_range("ad_unnest_sequences")
def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
t_squeezed = t_nested.squeeze(int(self.is_generate))
return list(torch.split(t_squeezed, self.seq_len))
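
A minimal usage sketch of the new SequenceInfo plumbing outside the engine (illustrative; the vocab size and CUDA device are assumed):

import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo

si = SequenceInfo(max_seq_len=64, max_batch_size=8, max_num_tokens=1024)
si.to("cuda")
si.reset()
si.nest_sequences([[1, 2, 3], [4, 5]])        # no logits_gather_* passed in
# defaults: logits_gather_indices = [0, 1, 2, 3, 4], logits_gather_info = [5, 1]
logits = torch.randn(1, si.total_num_tokens, 1000, device="cuda")   # [1, s_total, vocab]
out = si.maybe_gather_and_squeeze_logits(logits)
# all 5 tokens are requested, so nothing is dropped; only the packing dim is squeezed
assert out.shape == (5, 1000)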

View File

@ -0,0 +1,41 @@
import torch
@torch.library.custom_op("auto_deploy::gather_logits_before_lm_head", mutates_args=())
def gather_logits_before_lm_head(
hidden_states: torch.Tensor,
logits_gather_indices: torch.Tensor, # long tensor
logits_gather_info: torch.Tensor, # int tensor
) -> torch.Tensor:
"""Gather hidden states using logits_gather_indices before LM head.
Args:
hidden_states: Hidden states tensor [b, 1, hidden] or [1, s_total, hidden]
logits_gather_indices: indices for gathering logits.
logits_gather_info: info for gathering logits.
Returns:
Gathered hidden states with the input's 3D layout preserved: [num_gathered_tokens, 1, hidden]
for decode-only inputs or [1, num_gathered_tokens, hidden] for packed inputs.
"""
# squeeze to [num_tokens, hidden] from [b, 1, hidden] (decode-only) or [1, total_tokens, hidden] (packed)
is_decode_only = hidden_states.shape[1] == 1
hidden_states = hidden_states.squeeze(int(is_decode_only))
# unpack gather metadata: [num_tokens_to_gather, gather_required]
num_tokens_to_gather, gather_required = logits_gather_info.tolist()
if gather_required:
out = hidden_states.index_select(0, logits_gather_indices[:num_tokens_to_gather])
else:
out = hidden_states.clone(memory_format=torch.contiguous_format)
return out.unsqueeze(int(is_decode_only))
@gather_logits_before_lm_head.register_fake
def gather_logits_before_lm_head_fake(
hidden_states: torch.Tensor,
logits_gather_indices: torch.Tensor,
logits_gather_info: torch.Tensor,
) -> torch.Tensor:
# NOTE: shape is not correct in fake mode
# see https://github.com/NVIDIA/TensorRT-LLM/issues/9878
return torch.empty_like(hidden_states)
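
A small eager-mode check of the op's semantics (illustrative; assumes the module above has been imported so the op is registered):

# Packed input [1, s_total, hidden]: keep tokens 2, 4 and 5 only.
hs = torch.randn(1, 6, 4, device="cuda", dtype=torch.float16)
idx = torch.tensor([2, 4, 5], dtype=torch.long, device="cuda")
info = torch.tensor([3, 1], dtype=torch.int32, device="cuda")   # [num_to_gather, gather_required]
out = torch.ops.auto_deploy.gather_logits_before_lm_head.default(hs, idx, info)
assert out.shape == (1, 3, 4)           # 3D layout is preserved
assert torch.equal(out[0], hs[0, idx])  # exact rows from hs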

View File

@ -10,6 +10,7 @@
# limitations under the License.
import copy
import types
from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
@ -319,7 +320,7 @@ class ADEngine(ModelEngine):
return self.cache_seq_interface.device
@classmethod
def build_from_config(cls, ad_config: LlmArgs):
def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None):
"""Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
max_batch_size = ad_config.max_batch_size
@ -360,6 +361,7 @@ class ADEngine(ModelEngine):
seq_info,
device,
ad_config=ad_config,
mapping=mapping,
reporting_info=reporting_info,
)
@ -370,6 +372,7 @@ class ADEngine(ModelEngine):
seq_info: SequenceInfo,
device: DeviceLikeType,
ad_config: Optional[LlmArgs] = None,
mapping: Optional[Mapping] = None,
reporting_info: ReportingInfo = ReportingInfo(),
) -> None:
"""Initialize the engine with model and sequence information."""
@ -439,13 +442,20 @@ class ADEngine(ModelEngine):
# keep a reference for one dummy request around
self.padding_dummy_request: Optional[LlmRequest] = None
# Reuse _execute_logit_post_processors from PyTorchModelEngine
self.mapping = mapping
self._execute_logit_post_processors = types.MethodType(
PyTorchModelEngine._execute_logit_post_processors, self
)
@nvtx_range("ad_prepare_inputs")
def _prepare_inputs(
self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tokens: Optional[torch.Tensor] = None,
) -> List[bool]:
gather_context_logits: bool = False,
) -> None:
"""Prepare inputs for AD Model from scheduled requests."""
# cache manager
kv_cache_manager = resource_manager.get_resource_manager(
@ -467,7 +477,6 @@ class ADEngine(ModelEngine):
input_pos: List[int] = []
seq_len: List[int] = []
cu_seqlen: List[int] = [0]
last_logit_only: List[bool] = []
cache_loc: List[int] = []
pages_per_seq: List[int] = []
cu_num_pages: List[int] = [0]
@ -481,6 +490,9 @@ class ADEngine(ModelEngine):
mask_scatter_indices: List[int] = []
extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)
# gather indices for logits
logits_gather_indices: List[int] = []
page_size = self.cache_seq_interface.info.page_size
dummy_token = -1
num_ctx_requests = len(context_requests)
@ -505,7 +517,11 @@ class ADEngine(ModelEngine):
cu_seqlen.append(cu_seqlen[-1] + seq_len[-1])
request.py_batch_idx = request.seq_slot
last_logit_only.append(True)
if gather_context_logits:
logits_gather_indices.extend(range(cu_seqlen[-2], cu_seqlen[-1]))
else:
logits_gather_indices.append(cu_seqlen[-1] - 1)
# get cache indices and truncate the number of blocks according to end_compute
cache_indices = kv_cache_manager.get_cache_indices(request)
@ -600,11 +616,13 @@ class ADEngine(ModelEngine):
request.py_batch_idx = request.seq_slot
slot_idx.append(request.seq_slot)
use_initial_states.append(input_pos[-1] > 0)
last_logit_only.append(False)
seq_len.append(len(input_ids[-1]))
cu_seqlen.append(cu_seqlen[-1] + seq_len[-1])
# for generate requests, we always keep all logits (target logits + draft logits)
logits_gather_indices.extend(range(cu_seqlen[-2], cu_seqlen[-1]))
if use_overlap:
mask_scatter_indices.extend(list(range(cu_seqlen[-2], cu_seqlen[-1])))
@ -618,6 +636,14 @@ class ADEngine(ModelEngine):
position_ids.append(list(range(input_pos[-1], seq_len_with_cache[-1])))
# compute logits_gather_info
# gathering is only required when both hold:
# 1. the batch contains context requests, and
# 2. we are not gathering context logits
# otherwise (decode-only batches, or when all context logits are kept) all logits are
# needed and no gather is performed
gather_required = len(context_requests) > 0 and not gather_context_logits
logits_gather_info = [len(logits_gather_indices), int(gather_required)]
# update the sequence info object now
self.cache_seq_interface.info.nest_sequences(
input_ids,
@ -632,6 +658,8 @@ class ADEngine(ModelEngine):
last_page_len=last_page_len,
slot_idx=slot_idx,
use_initial_states=use_initial_states,
logits_gather_indices=logits_gather_indices,
logits_gather_info=logits_gather_info,
_gather_idx=None if new_tokens is None else flat_gather_indices,
_mask_scatter_indices=None if new_tokens is None else mask_scatter_indices,
**extra_args,
@ -644,18 +672,15 @@ class ADEngine(ModelEngine):
self.iter_states["num_ctx_tokens"] = num_ctx_tokens
# TODO: handle extend requests and draft requests for specdec
self.iter_states["num_generation_tokens"] = num_generation_tokens
return last_logit_only
@nvtx_range("ad_compute_logits")
def _compute_logits(self) -> List[torch.Tensor]:
# run the model
logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
logits = self.cache_seq_interface.info.maybe_gather_and_squeeze_logits(logits)
# TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
logits = logits.float()
# return a list of tensors
return self.cache_seq_interface.info.unnest_sequences(logits)
return logits.float()
def get_max_num_sequences(self) -> int:
"""Maximum number of sequences supported by the engine."""
@ -674,19 +699,18 @@ class ADEngine(ModelEngine):
"""Run forward from scheduled requests; main entrypoint that gets called by the executor."""
# convert requests and store in sequence info object
new_tokens = getattr(new_tensors_device, "new_tokens", None)
last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)
self._prepare_inputs(
scheduled_requests, resource_manager, new_tokens, gather_context_logits
)
self.iter_counter += 1
# compute all logits
logits = self._compute_logits()
outputs = {
"logits": self._compute_logits(),
}
if self.mapping is not None:
self._execute_logit_post_processors(scheduled_requests, outputs)
# gather+cat logits
logits_flat = torch.cat(
[ls_one_seq[-last_only:] for ls_one_seq, last_only in zip(logits, last_logit_only)],
dim=0,
)
return {"logits": logits_flat}
return outputs
def create_draft_model_engine_maybe(
@ -825,7 +849,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
)
# initialize model engine
engine = ADEngine.build_from_config(ad_config=ad_config)
engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)
spec_config = ad_config.speculative_config
if spec_config is not None and not spec_config.spec_dec_mode.is_draft_target():
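
To make the interface change concrete (illustrative; engine, scheduled_requests and resource_manager set up as in the tests further down): for the mixed batch from the worked example above, the old path returned a list of per-sequence logits plus last_logit_only flags and concatenated the kept rows outside the model; the new path gathers before (or right after) the LM head, so forward() already hands back a flat tensor with one row per gathered token.

# Old: _compute_logits() -> [(3, V), (2, V), (1, V)], last_logit_only = [True, True, False];
#      forward() concatenated the kept rows into a (3, V) tensor.
# New: indices [2, 4, 5] are gathered before the LM head (in-graph when the transform is
#      enabled, otherwise in maybe_gather_and_squeeze_logits), so:
outputs = engine.forward(scheduled_requests, resource_manager)
assert outputs["logits"].shape[0] == 3   # one row per gathered token, float32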

View File

@ -137,7 +137,7 @@ class DemoEngine(ADEngine):
context_logits: Optional[List[torch.Tensor]] = None
def _generate_single_step(idx: int):
logits = self._compute_logits()
logits = sequence_info.unnest_sequences(self._compute_logits())
logits_last = torch.stack([l_one_seq[-1] for l_one_seq in logits]).float().unsqueeze(1)
token_ids, _ = self._decode_tokens(logits_last, sampling_params) # [b,1]

View File

@ -12,11 +12,12 @@ from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final
import torch.nn as nn
from pydantic import BaseModel, Field
from torch.fx import GraphModule
from torch.fx import GraphModule, Node
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..utils._graph import (
add_graph_input,
canonicalize_graph,
lift_to_meta,
named_graphmodules,
@ -505,6 +506,19 @@ class BaseTransform(ABC):
f"Transform {self.get_transform_key()} only supports `run_per_gm=True`."
)
def _add_or_retrieve_input(
self, gm: GraphModule, cm: CachedSequenceInterface, name: str
) -> Node:
"""Add or retrieve an input node from the graph."""
input_nodes = gm.graph.find_nodes(op="placeholder", target=name)
if len(input_nodes) == 0:
cm.info.activate_arg(name)
return add_graph_input(gm, name)
elif len(input_nodes) == 1:
return input_nodes[0]
else:
raise ValueError(f"Expected exactly one input node for {name=}, got {input_nodes=}")
class TransformRegistry:
"""A registry for all transforms."""

View File

@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transform to gather hidden states before LM head using logits_gather_mask.
This moves the gather operation into the model graph before the LM head,
enabling CUDA graph capture and reducing computation by only computing
logits for the tokens that are actually needed.
"""
from typing import Tuple
import torch
from torch.fx import GraphModule
from ...utils.node_utils import is_linear_op, is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@TransformRegistry.register("gather_logits_before_lm_head")
class GatherLogitsBeforeLmHeadTransform(BaseTransform):
"""Transform to gather hidden states before LM head using logits_gather_mask.
This transform inserts a gather operation before the LM head linear layer
to select only the hidden states that need logits computed. The output is always
[b, hidden_size] in decode-only for CUDA graph compatibility.
Benefits:
- Reduces computation by only computing logits for needed tokens
- Eliminates Python loop overhead
- Enables CUDA graph capture of the gather
- Moves gather into the graph for better optimization
"""
def _apply(
self,
gm: GraphModule,
cm,
factory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
self._log_info("Applying GatherLogitsBeforeLmHead transform...")
# assume lm head node is the input to the output node
lm_head_node = gm.graph.find_nodes(op="output")[0].all_input_nodes[0]
if is_op(lm_head_node, torch.ops.aten.to):
lm_head_node = lm_head_node.all_input_nodes[0]
if is_linear_op(lm_head_node):
node_to_gather = lm_head_node.all_input_nodes[0]
self._log_info(f"Found LM head node: {lm_head_node.name}")
else:
node_to_gather = lm_head_node
self._log_info("lm_head node is not linear, using it as the node to gather")
# Add logits_gather_indices and logits_gather_info as graph inputs and activate them in the sequence info interface
logits_gather_indices_node = self._add_or_retrieve_input(gm, cm, "logits_gather_indices")
logits_gather_info_node = self._add_or_retrieve_input(gm, cm, "logits_gather_info")
with gm.graph.inserting_after(node_to_gather):
gathered_node = gm.graph.call_function(
torch.ops.auto_deploy.gather_logits_before_lm_head.default,
args=(node_to_gather, logits_gather_indices_node, logits_gather_info_node),
)
node_to_gather.replace_all_uses_with(gathered_node)
gathered_node.replace_input_with(gathered_node, node_to_gather)
return gm, TransformInfo(skipped=False, num_matches=1)
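
Schematically (node names illustrative), the rewiring above changes the tail of the graph as follows:

# before:  hidden = <last transformer op>(...)
#          logits = linear(hidden, lm_head_weight, lm_head_bias)
#          output(logits)
# after:   hidden = <last transformer op>(...)
#          gathered = auto_deploy.gather_logits_before_lm_head(
#              hidden, logits_gather_indices, logits_gather_info)
#          logits = linear(gathered, lm_head_weight, lm_head_bias)
#          output(logits)
# The replace_all_uses_with / replace_input_with pair is the usual FX idiom for splicing
# a new node between a producer and all of its existing consumers.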

View File

@ -54,19 +54,6 @@ class InsertCachedAttention(BaseTransform):
def attn_descriptor(self) -> Type[AttentionDescriptor]:
return AttentionRegistry.get(self.config.backend)
def _add_or_retrieve_input(
self, gm: GraphModule, cm: CachedSequenceInterface, name: str
) -> Node:
"""Add or retrieve an input node from the graph."""
input_nodes = gm.graph.find_nodes(op="placeholder", target=name)
if len(input_nodes) == 0:
cm.info.activate_arg(name)
return add_graph_input(gm, name)
elif len(input_nodes) == 1:
return input_nodes[0]
else:
raise ValueError(f"Expected exactly one input node for {name=}, got {input_nodes=}")
def _process_metadata_std(self, gm: GraphModule, cm: CachedSequenceInterface) -> List[Node]:
"""Process the standard metadata nodes."""
return [

View File

@ -71,7 +71,6 @@ def test_engine(engine_cls: Type[ADEngine], attn_page_size: int):
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
logits = engine._compute_logits()
logits = torch.stack(logits)
assert logits is not None, "Logits are None"
mock_input = None
@ -105,7 +104,6 @@ def test_demo_engine_sampling(attn_page_size: int):
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
logits = engine._compute_logits()
logits = torch.stack(logits)
vocab_size = logits.size(-1)
sampling_params = SamplingParams(top_k=5, temperature=1.0)
@ -199,7 +197,7 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
# No-chunk: whole prompt in one request
req_full = _DummyRequest(tokens=tokens, begin=0, size=len(tokens), seq_slot=0)
scheduled_full = SimpleNamespace(context_requests=[req_full], generation_requests=[])
logits_full = engine.forward(scheduled_full, resource_manager)["logits"]
logits_full_last = engine.forward(scheduled_full, resource_manager)["logits"][-1]
# Chunked: split into two context chunks
split = len(tokens) // 2
@ -211,7 +209,6 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
# Run first chunk (ignored output), then compare second chunk logits to full
_ = engine.forward(scheduled_part1, resource_manager)
logits_chunked_last = engine.forward(scheduled_part2, resource_manager)["logits"]
logits_chunked_last = engine.forward(scheduled_part2, resource_manager)["logits"][-1]
assert logits_full.shape == logits_chunked_last.shape
assert torch.allclose(logits_full, logits_chunked_last, atol=1e-5)
torch.testing.assert_close(logits_full_last, logits_chunked_last)

View File

@ -0,0 +1,326 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for gather_logits_before_lm_head transform."""
import pytest
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Dim
from torch.fx import GraphModule
# Import to register custom op
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
class SimpleLMHeadModel(torch.nn.Module):
"""Simple model with LM head for testing."""
def __init__(self, hidden_size: int = 128, vocab_size: int = 1000):
super().__init__()
self.linear1 = torch.nn.Linear(hidden_size, hidden_size, device="cuda", dtype=torch.float16)
self.lm_head = torch.nn.Linear(hidden_size, vocab_size, device="cuda", dtype=torch.float16)
def forward(self, hidden_states, logit_gather_ids=None, seq_len=None):
# Simulate transformer output
hidden_states = self.linear1(hidden_states)
# LM head
logits = self.lm_head(hidden_states)
return logits
class TestGatherLogitsBeforeLmHeadOp:
"""Test the custom op directly."""
@pytest.mark.parametrize("batch_size", [1, 4, 8])
def test_generate_format(self, batch_size):
"""Test gather op with generate format input [batch, 1, hidden]."""
hidden_size = 128
hidden_states = torch.randn(batch_size, 1, hidden_size, device="cuda", dtype=torch.float16)
# Create gather info: num_tokens_to_gather=batch_size, gather_required=0 (False)
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states, logits_gather_indices, logits_gather_info
)
# Should return [batch, 1, hidden] for generate format (3D shape preserved)
assert output.shape == (batch_size, 1, hidden_size)
assert output.dtype == hidden_states.dtype
assert output.device == hidden_states.device
torch.testing.assert_close(output, hidden_states)
@pytest.mark.parametrize("total_tokens", [10, 50, 100])
def test_packed_format(self, total_tokens):
"""Test gather op with packed format input [1, total_tokens, hidden]."""
hidden_size = 128
max_batch_size = 8
hidden_states = torch.randn(
1, total_tokens, hidden_size, device="cuda", dtype=torch.float16
)
# Create gather indices: gather the first num_gather tokens (indices [0, 1, ..., num_gather-1])
num_gather = min(total_tokens, max_batch_size)
gather_indices = torch.arange(0, num_gather, dtype=torch.long, device="cuda")
# Create gather info: num_tokens_to_gather=num_gather, gather_required=1 (True)
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states, gather_indices, logits_gather_info
)
# Should return [1, num_gather, hidden] for packed format (3D shape preserved)
assert output.shape == (1, num_gather, hidden_size)
assert output.dtype == hidden_states.dtype
assert output.device == hidden_states.device
# Verify gathered values match expected indices
expected = hidden_states[:, gather_indices, :]
torch.testing.assert_close(output, expected)
def test_fake_implementation_generate_format(self):
"""Test fake implementation for generate format."""
batch_size = 4
hidden_size = 128
hidden_states = torch.randn(batch_size, 1, hidden_size, device="cuda", dtype=torch.float16)
# Create gather info
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
# Use fake implementation directly
with FakeTensorMode() as mode:
hidden_states_fake = mode.from_tensor(hidden_states)
logits_gather_indices_fake = mode.from_tensor(logits_gather_indices)
logits_gather_info_fake = mode.from_tensor(logits_gather_info)
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake
)
# Should return [batch, 1, hidden_size] (fake returns empty_like which preserves 3D shape)
assert output.shape == (batch_size, 1, hidden_size)
assert output.dtype == hidden_states.dtype
assert output.device == hidden_states.device
def test_fake_implementation_packed_format(self):
"""Test fake implementation for packed format."""
total_tokens = 50
hidden_size = 128
num_gather = 8
hidden_states = torch.randn(
1, total_tokens, hidden_size, device="cuda", dtype=torch.float16
)
# Create gather info
logits_gather_indices = torch.arange(num_gather, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
# Use fake implementation directly
with FakeTensorMode() as mode:
hidden_states_fake = mode.from_tensor(hidden_states)
logits_gather_indices_fake = mode.from_tensor(logits_gather_indices)
logits_gather_info_fake = mode.from_tensor(logits_gather_info)
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake
)
# The fake implementation returns empty_like which preserves input shape [1, total_tokens, hidden]
assert output.shape == (1, total_tokens, hidden_size)
assert output.dtype == hidden_states.dtype
assert output.device == hidden_states.device
class TestGatherLogitsBeforeLmHeadTransform:
"""Test the transform application."""
def _create_cached_sequence_interface(self, max_batch_size: int = 8, device: str = "cuda"):
"""Create a mock CachedSequenceInterface for testing."""
seq_info = SequenceInfo(
max_seq_len=64,
max_batch_size=max_batch_size,
max_num_tokens=1024,
)
seq_info.to(device)
return CachedSequenceInterface(seq_info, device=device)
def _check_gather_op_in_graph(self, gm: GraphModule) -> bool:
"""Check if gather_logits_before_lm_head op is in the graph."""
return any(
is_op(n, torch.ops.auto_deploy.gather_logits_before_lm_head) for n in gm.graph.nodes
)
@pytest.mark.parametrize("batch_size", [1, 4, 8])
def test_transform_generate_format(self, batch_size):
"""Test transform with generate format input."""
hidden_size = 128
vocab_size = 1000
model = SimpleLMHeadModel(hidden_size, vocab_size).cuda()
# Create input in generate format [batch, 1, hidden]
hidden_states = torch.randn(batch_size, 1, hidden_size, device="cuda", dtype=torch.float16)
max_batch_size = 8
logit_gather_ids = torch.zeros(max_batch_size, dtype=torch.long, device="cuda")
seq_len = torch.ones(batch_size, dtype=torch.long, device="cuda")
# Export model
# When batch_size=1, torch.export specializes it to a constant, so we skip dynamic shapes
# For batch_size > 1, we use dynamic shapes to test the transform with varying batch sizes
if batch_size == 1:
dynamic_shapes = None
else:
# dynamic_shapes should be a tuple matching the number of positional args
dynamic_shapes = (
{0: Dim("batch_size", min=1, max=max_batch_size)}, # hidden_states
None, # logit_gather_ids (static)
None, # seq_len (static)
)
gm = torch_export_to_gm(
model,
args=(hidden_states, logit_gather_ids, seq_len),
dynamic_shapes=dynamic_shapes,
clone=True,
)
# Apply transform
cm = self._create_cached_sequence_interface(max_batch_size)
transform_config = {
"gather_logits_before_lm_head": {
"stage": "post_load_fusion",
"max_batch_size": max_batch_size,
}
}
optimizer = InferenceOptimizer(None, transform_config)
gm_transformed = optimizer(cm, gm)
# Check that gather op was inserted
assert self._check_gather_op_in_graph(gm_transformed), "Gather op not found in graph"
# Test forward pass
# We must pass the new graph inputs manually since we are running the graph directly
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
output = gm_transformed(
hidden_states,
logit_gather_ids,
seq_len,
logits_gather_indices=logits_gather_indices,
logits_gather_info=logits_gather_info,
)
# Output should be [batch_size, 1, vocab_size] since gather now returns 3D
assert output.shape == (batch_size, 1, vocab_size)
@pytest.mark.parametrize("total_tokens", [10, 50])
def test_transform_packed_format(self, total_tokens):
"""Test transform with packed format input."""
hidden_size = 128
vocab_size = 1000
max_batch_size = 8
model = SimpleLMHeadModel(hidden_size, vocab_size).cuda()
# Create input in packed format [1, total_tokens, hidden]
hidden_states = torch.randn(
1, total_tokens, hidden_size, device="cuda", dtype=torch.float16
)
logit_gather_ids = torch.arange(
0, min(total_tokens, max_batch_size), dtype=torch.long, device="cuda"
)
# Pad to max_batch_size
logit_gather_ids_padded = torch.zeros(max_batch_size, dtype=torch.long, device="cuda")
logit_gather_ids_padded[: len(logit_gather_ids)] = logit_gather_ids
seq_len = torch.ones(max_batch_size, dtype=torch.long, device="cuda")
seq_len[: len(logit_gather_ids)] = torch.ones(
len(logit_gather_ids), dtype=torch.long, device="cuda"
)
# Export model
gm = torch_export_to_gm(
model,
args=(hidden_states, logit_gather_ids_padded, seq_len),
dynamic_shapes=None,
clone=True,
)
# Apply transform
cm = self._create_cached_sequence_interface(max_batch_size)
transform_config = {
"gather_logits_before_lm_head": {
"stage": "post_load_fusion",
"max_batch_size": max_batch_size,
}
}
optimizer = InferenceOptimizer(None, transform_config)
gm_transformed = optimizer(cm, gm)
# Check that gather op was inserted
assert self._check_gather_op_in_graph(gm_transformed), "Gather op not found in graph"
# Test forward pass
# We must pass the new graph inputs manually since we are running the graph directly
num_gather = len(logit_gather_ids)
logits_gather_indices = logit_gather_ids
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
output = gm_transformed(
hidden_states,
logit_gather_ids_padded,
seq_len,
logits_gather_indices=logits_gather_indices,
logits_gather_info=logits_gather_info,
)
# Output should be [1, num_gather, vocab_size] since gather now returns 3D
assert output.shape == (1, num_gather, vocab_size)
def test_transform_skips_when_disabled(self):
"""Test that transform skips when disabled."""
hidden_size = 128
vocab_size = 1000
model = SimpleLMHeadModel(hidden_size, vocab_size).cuda()
hidden_states = torch.randn(4, 1, hidden_size, device="cuda", dtype=torch.float16)
max_batch_size = 8
logit_gather_ids = torch.zeros(max_batch_size, dtype=torch.long, device="cuda")
seq_len = torch.ones(4, dtype=torch.long, device="cuda")
# Export model
gm = torch_export_to_gm(
model,
args=(hidden_states, logit_gather_ids, seq_len),
dynamic_shapes=None,
clone=True,
)
# Apply transform with disabled config
cm = self._create_cached_sequence_interface(max_batch_size)
transform_config = {
"gather_logits_before_lm_head": {
"stage": "post_load_fusion",
"enabled": False,
"max_batch_size": max_batch_size,
}
}
optimizer = InferenceOptimizer(None, transform_config)
gm_transformed = optimizer(cm, gm)
# Check that gather op was NOT inserted
assert not self._check_gather_op_in_graph(gm_transformed), (
"Gather op should not be in graph"
)