[#5048][enhance] AutoDeploy: Optimize prepare_inputs (#6634)

Optimize the prepare_inputs routine in AutoDeploy as part of the effort to close the performance gap with the default backend.
This PR includes two major optimizations and some other minor tweaks (a short sketch of both patterns follows the list):
1. Avoid back-and-forth data copies between host and device.
2. Optimize the position_ids update by separating the implementations for generation mode and context mode.
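As an illustration of both items, here is a minimal sketch. It is a hypothetical example, not the actual SequenceInfo implementation: the class name SeqInfoSketch and the method update_position_ids are made up for this example, and a CUDA device is assumed to be available. The idea is to stage values in pinned host memory, copy them asynchronously into a preallocated device buffer, and take a fast path for generation-only batches where position_ids reduces to the per-sequence input position.

import torch
from typing import List

class SeqInfoSketch:
    def __init__(self, max_num_tokens: int, device: str = "cuda"):
        # Preallocated device buffer reused every step; avoids allocating a new
        # position_ids tensor (and the associated copies) on each forward pass.
        self.position_ids_full = torch.zeros(max_num_tokens, dtype=torch.long, device=device)

    def update_position_ids(self, input_pos: List[int], seq_lens: List[int]) -> torch.Tensor:
        if all(s == 1 for s in seq_lens):
            # Generation-mode fast path: one token per sequence, so position_ids
            # is simply the current input position of each sequence.
            host = torch.tensor(input_pos, dtype=torch.long, pin_memory=True)
        else:
            # Context mode: expand [pos, pos + n) per sequence on the host.
            ids = [p for pos, n in zip(input_pos, seq_lens) for p in range(pos, pos + n)]
            host = torch.tensor(ids, dtype=torch.long, pin_memory=True)
        # Single non-blocking host-to-device copy from pinned memory into the
        # preallocated buffer; no new device tensor is created.
        num_tokens = host.numel()
        self.position_ids_full[:num_tokens].copy_(host, non_blocking=True)
        return self.position_ids_full[:num_tokens]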

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Gal Hubara-Agam 2025-08-10 13:55:04 +03:00 committed by GitHub
parent ee19ca5e58
commit 3c5aec19c2
4 changed files with 175 additions and 69 deletions

View File

@@ -162,7 +162,7 @@ class CapturedGraph(nn.Module):
# copy inputs to input buffers
for i, input_tensor in enumerate(args_batched):
self._input_buffers[i][: input_tensor.shape[0]] = input_tensor
self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True)
# run forward pass via graph
self.graphs[combined_shape].replay()
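For context, the in-place copy in this hunk feeds input buffers whose addresses were recorded during CUDA graph capture, which is why the data must be written into the existing storage rather than rebinding the buffer. Below is a generic capture-and-replay sketch of that pattern, using a plain nn.Linear as a stand-in for the real model; all names and shapes are illustrative, and a CUDA device is assumed.

import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream before capture, as recommended by the PyTorch CUDA graph docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass; the graph records the addresses of static_input/static_output.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Per step: copy new data into a slice of the captured buffer (keeping its storage,
# and thus the recorded addresses, intact) and replay the graph.
new_batch = torch.randn(4, 16, device="cuda")
static_input[: new_batch.shape[0]].copy_(new_batch, non_blocking=True)
graph.replay()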

View File

@@ -18,6 +18,8 @@ from torch._ops import OpOverloadPacket
from torch.export import Dim
from torch.fx import Node
from tensorrt_llm._utils import nvtx_range
@dataclass
class CacheConfig:
@@ -87,11 +89,13 @@ class SequenceInfo:
# Similarly, if a batch is composed of generate-only requests,
# then the maximum number of sequences possible in the batch is min(max_batch_size, max_num_tokens).
max_num_tokens: Optional[int] = None
# device is the device on which the sequence info is stored.
device: str = "cuda"
## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
# input_ids MUST ALWAYS BE THE FIRST FIELD
input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int))
position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long))
input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long))
seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
@@ -110,24 +114,44 @@ class SequenceInfo:
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
# (max_batch_size, max_seq_len) input in trtllm runtime.
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
max_seq_len_adjusted = self.max_seq_len + 1
self.max_seq_len_adjusted = self.max_seq_len + 1
if self.max_num_tokens is None or self.max_num_tokens < 1:
self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted
# if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
# we use the provided max_num_tokens to calculate the number of pages
total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted)
# Num pages cannot be less than max_batch_size.
self._num_pages = max(
self.max_batch_size,
(total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
)
self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
self.input_pos = torch.empty_like(self.seq_len)
self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
self.pages_per_seq = torch.empty_like(self.seq_len)
# Ensure that the device is set before initializing the tensors.
self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
# Consumers of the sequence info args require input_ids and position_ids to be truncated.
# We maintain full-size versions of input_ids and position_ids to avoid the overhead of
# tensor creation in every forward pass.
self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
self.position_ids_full = torch.zeros(
self.max_num_tokens, dtype=torch.long, device=self.device
)
self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device)
self.input_pos = torch.empty_like(self.seq_len, device=self.device)
# Allocate host tensors for sequence lengths and input positions so that
# the position_ids calculation can be done on the host.
self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int)
self.input_pos_host = torch.empty_like(self.seq_len_host)
self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device)
self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device)
self.previous_batch_indices_cuda = torch.empty(
self.max_num_tokens, dtype=torch.long, device=self.device
)
assert self.num_pages >= self.max_batch_size, (
"num_pages must be at least max_batch_size"
)
@@ -140,13 +164,12 @@ class SequenceInfo:
# indicator if extra args are activated that are needed for cached attention backends
self._is_cached_attn = False
# total number of tokens in the current batch
self.num_tokens: int = 0
# call reset once to initialize the tensors
self.reset()
@property
def device(self) -> torch.device:
return self.input_pos.device
@property
def args(self) -> Tuple[torch.Tensor, ...]:
args = []
@@ -156,11 +179,14 @@ class SequenceInfo:
args.append(val)
if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn:
break
return tuple(args)
@property
def _num_uncached_attn_args(self) -> int:
"""Return the number of original graph arguments expected by the model."""
"""Return the number of original graph arguments expected by the model.
This is 2 because we have input_ids and position_ids as the original graph arguments.
"""
return 2
@property
@@ -185,7 +211,7 @@ class SequenceInfo:
dynamic_shapes = ({}, {})
if self.max_batch_size > 1:
dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len)
dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted)
# set up shape for position_ids (same as input_ids)
dynamic_shapes[1].update(dynamic_shapes[0])
# set up shape for extra args
@@ -204,7 +230,7 @@ class SequenceInfo:
@property
def input_positions(self) -> List[int]:
return self.input_pos[: self.num_sequences].tolist()
return self.input_pos_host[: self.num_sequences].tolist()
@property
def is_generate(self) -> bool:
@@ -334,14 +360,19 @@ class SequenceInfo:
"""
# reset input_pos
self.input_pos.zero_()
self.input_pos_host.zero_()
# set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))
self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
# reset cache information
self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
self.pages_per_seq.fill_(1)
# let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens)
self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
def set_example_sequence(self) -> None:
"""Set an example sequence useful for testing and export purposes."""
self.reset()
@@ -352,7 +383,7 @@ class SequenceInfo:
dtype=torch.int,
device=self.device,
)
self.nest_sequences(input_ids)
self.nest_sequences(input_ids, allow_realloc=True)
# unflatten if we are not yet using cached+flattened attention
if not self._is_cached_attn:
@@ -370,7 +401,7 @@ class SequenceInfo:
device=self.device,
)
self.pages_per_seq.fill_(seq_len // self.page_size)
self.nest_sequences(input_ids)
self.nest_sequences(input_ids, allow_realloc=True)
def set_generate_only_batch(self) -> None:
"""Set an example sequence for generate-only batch.
@@ -379,32 +410,96 @@ class SequenceInfo:
mode. So we don't need to do anything mode-specific here.
"""
self.reset()
self.nest_sequences([[1]] * self.max_batch_size)
def _update_position_ids(self) -> None:
# set new position_ids as new tensor from input_pos and seq_len via torch.arange
position_ids_list = [
num
for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
for num in range(in_pos, in_pos + seq_len)
]
self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device)
self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor:
# use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
if self.is_generate:
self.position_ids = self.position_ids.view(-1, 1)
return tensor.view(-1, 1, *tensor.shape[1:])
else:
self.position_ids = self.position_ids.view(1, -1)
return tensor.view(1, -1, *tensor.shape[1:])
def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
@nvtx_range("ad_update_position_ids")
def _update_position_ids(self, allow_realloc: bool = False) -> None:
# set new position_ids from input_pos and seq_len
# Make sure this is done on host to avoid host-device copies.
with nvtx_range("prepare_list"):
# Optimize for the common case where all seq_len values are 1 (generation mode)
if torch.all(self.seq_len_host == 1):
# Fast path: when all seq_len are 1, position_ids is just input_pos_host
position_ids_host = (
self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory()
)
else:
# General case - can probably be optimized too, but overall impact will be minor.
position_ids_list = []
for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host):
position_ids_list.extend(range(in_pos, in_pos + seq_len))
position_ids_host = torch.tensor(
position_ids_list, dtype=torch.long, pin_memory=True
)
with nvtx_range("copy_to_device"):
if allow_realloc:
# Create a new position_ids tensor on the device
self.position_ids = position_ids_host.to(self.device).clone()
else:
self.position_ids_full = self.position_ids_full.flatten()
self.position_ids_full[: self.num_tokens].copy_(
position_ids_host, non_blocking=True
)
with nvtx_range("maybe_reshape"):
self.position_ids = self.maybe_reshape_for_generate(
self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens]
)
@nvtx_range("ad_update_sequence_lengths")
def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None:
self._sequence_lengths = sequence_lengths
self.num_tokens = sum(self._sequence_lengths)
self.seq_len.zero_()
self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True)
self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True)
def update_input_ids_with_new_tokens(
self, new_tokens: torch.Tensor, previous_batch_indices: List[int]
) -> None:
"""Update the input_ids with new tokens.
This function will update the input_ids with new tokens and previous batch indices.
"""
# 1) flatten the preallocated input_ids buffer once
original_shape = self.input_ids.shape
flat = self.input_ids.flatten()
# 2) copy the previous batch indices to the GPU
host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True)
idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
idx.copy_(host_idx, non_blocking=True)
# 3) sort them so that masked_scatter_ lines up correctly
idx, _ = idx.sort()
# 4) gather the exact values to write
src = new_tokens[0, idx, 0]
# 5) in-place, fill every slot where flat == -1 with src, in order
flat.masked_scatter_(flat == -1, src)
# 6) reshape back to the original shape
self.input_ids = flat.view(original_shape)
@nvtx_range("ad_nest_sequences")
def nest_sequences(
self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False
) -> None:
"""Create and store a flattened list of input_ids from the provided list of sequences.
When allow_realloc is True, the input_ids tensor will be reallocated on the device.
This function also updates any relevant sequence information.
"""
# set new sequence lengths
seq_lens = [len(ids) for ids in input_ids]
self.seq_len.zero_()
self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)
self._update_sequence_lengths([len(ids) for ids in input_ids])
# Preserve the dtype if input_ids is already a tensor; otherwise default to int
dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int
# set new input_ids as new tensor from flattened input_ids
@@ -413,49 +508,57 @@ class SequenceInfo:
for lst in input_ids
for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
]
self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device)
input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True)
# set derivative properties
self._sequence_lengths = seq_lens
# use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
if self.is_generate:
self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:])
if allow_realloc:
self.input_ids = input_ids_host.to(self.device).clone()
else:
self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:])
self.input_ids_full = self.input_ids_full.flatten()
self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True)
self.input_ids = self.maybe_reshape_for_generate(
self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens]
)
# update position_ids
self._update_position_ids()
self._update_position_ids(allow_realloc=allow_realloc)
@nvtx_range("ad_unnest_sequences")
def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
return list(torch.split(t_squeezed, self.sequence_lengths))
@nvtx_range("ad_update_pos")
def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
"""Update the starting position for each sequence in the cache.
If ``reset=True``, ``input_pos`` will be reset to zero before updating.
"""
if not isinstance(seq_len, torch.Tensor):
seq_len = torch.tensor(seq_len, dtype=torch.int)
seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True)
bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size
if reset:
self.input_pos[:bs] = seq_len.to(self.device)
self.input_pos_host[:bs].copy_(seq_len, non_blocking=True)
else:
self.input_pos[:bs] += seq_len.to(self.device)
self.input_pos_host[:bs] += seq_len
# update position_ids
self._update_position_ids()
self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True)
@nvtx_range("ad_assign_cache_loc")
def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
"""Set the cache location and pages_per_seq tensors from page assignments."""
cache_loc_flat = torch.tensor(
[p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
[p_idx for pages in page_assignments for p_idx in pages],
dtype=torch.int,
pin_memory=True,
)
self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
pages_per_seq = torch.tensor(
[len(p) for p in page_assignments], dtype=torch.int, pin_memory=True
)
self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
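The update_input_ids_with_new_tokens path above depends on masked_scatter_ consuming its source values in order as it fills masked positions, which is why the batch indices are sorted first. Here is a small, self-contained sketch of this placeholder-fill pattern; tensor values and shapes are illustrative assumptions, not the engine's actual state.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# input_ids as prepared: two known context tokens plus two generate slots seeded
# with the placeholder -1, to be filled from the previous step's sampled tokens.
input_ids = torch.tensor([11, 12, -1, -1], dtype=torch.int, device=device)

# New tokens from the previous step, shaped [beam, batch_slot, 1] (illustrative).
new_tokens = torch.tensor([[[101], [102], [103]]], dtype=torch.int, device=device)

# Batch slots whose placeholders need filling, sorted so that masked_scatter_
# lines up its source values with the -1 positions.
idx = torch.tensor([0, 2], dtype=torch.long, device=device)
src = new_tokens[0, idx, 0]                      # tensor([101, 103])
input_ids.masked_scatter_(input_ids == -1, src)  # input_ids becomes [11, 12, 101, 103]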

View File

@@ -94,20 +94,21 @@ class ADEngine(ModelEngine):
f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}, {max_beam_width=}"
)
# update device to contain the current default device if it's in cuda
device = torch.device(ad_config.device)
if device.type == "cuda" and device.index is None:
device = torch.device(f"cuda:{torch.cuda.current_device()}")
device = str(device)
# initialize seq info object
seq_info = SequenceInfo(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
page_size=attn_page_size,
max_num_tokens=max_num_tokens,
device=device,
)
# update device to contain the current default device if it's in cuda
device = torch.device(ad_config.device)
if device.type == "cuda" and device.index is None:
device = torch.device(f"cuda:{torch.cuda.current_device()}")
device = str(device)
# construct inference optimizer
build_and_optimize = InferenceOptimizer(
factory=ad_config.create_factory(), ad_config=ad_config
@@ -170,16 +171,12 @@ class ADEngine(ModelEngine):
context_requests = scheduled_requests.context_requests
gen_requests = [r for r in scheduled_requests.generation_requests if not r.draft_tokens]
# new_tokens is a tensor on the device, we need to convert it to a list of lists.
# can we avoid this additional gpu->cpu transfer?
new_tokens_list = new_tokens.flatten().cpu().tolist() if new_tokens is not None else None
# info to be extracted
input_ids: List[List[int]] = []
input_pos: List[int] = []
last_logit_only: List[bool] = []
page_assignments: List[List[int]] = []
previous_batch_indices: List[int] = []
# look at context requests first
for request in context_requests:
# store input ids and pos of first token in sequence
@@ -193,11 +190,13 @@ class ADEngine(ModelEngine):
# TODO: we should also handle extend requests (for speculative decoding) here
for request in gen_requests:
# new_tokens are provided when the overlap scheduler is enabled.
if new_tokens_list is None or request.is_dummy or request.py_batch_idx is None:
if new_tokens is None or request.is_dummy or request.py_batch_idx is None:
input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)])
input_pos.append(request.max_beam_num_tokens - 1)
else:
input_ids.append([new_tokens_list[request.py_batch_idx]])
# insert a placeholder token (-1) to be filled in later from new_tokens on the device
input_ids.append([-1])
previous_batch_indices.append(request.py_batch_idx)
input_pos.append(request.max_beam_num_tokens)
request.py_batch_idx = request.seq_slot
@@ -213,11 +212,15 @@ class ADEngine(ModelEngine):
# update the sequence info object now
si = self.cache_seq_interface.info
si.nest_sequences(input_ids)
si.update_pos(input_pos, reset=True)
si.assign_cache_loc(page_assignments)
si.nest_sequences(input_ids)
if new_tokens is not None:
si.update_input_ids_with_new_tokens(new_tokens, previous_batch_indices)
return last_logit_only
@nvtx_range("ad_compute_logits")
def _compute_logits(self) -> List[torch.Tensor]:
# run the model
logits: torch.Tensor = self.model(*self.cache_seq_interface.args)[0]
@@ -234,13 +237,13 @@ class ADEngine(ModelEngine):
self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tokens_device: Optional[torch.Tensor] = None,
new_tensors_device: Optional[torch.Tensor] = None,
gather_context_logits: bool = False,
cache_indirection_buffer: Optional[torch.Tensor] = None,
):
"""Run forward from scheduled requests; main entrypoint that gets called by the executor."""
# convert requests and store in sequence info object
new_tokens = getattr(new_tokens_device, "new_tokens", None)
new_tokens = getattr(new_tensors_device, "new_tokens", None)
last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)
# compute all logits

View File

@@ -187,7 +187,7 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config):
# Helper function to call the model with proper sequence nesting
def _call_and_unnest(x):
# Use nest_sequences to properly set input_ids and automatically update position_ids
cm.info.nest_sequences(x)
cm.info.nest_sequences(x, allow_realloc=True)
# Use the cm.args as is - it already contains the correct position_ids
y = gm(*cm.args)