Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
Optimize the prepare_inputs routine in AutoDeploy, as part of the effort to reduce the performance gap compared to the default backend.

This PR includes two major fixes and some other minor tweaks:
1. Avoid back-and-forth data copies.
2. Optimize the position_ids update by separating the implementation for generation mode and context mode.

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
parent ee19ca5e58
commit 3c5aec19c2
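Before the diff itself, here is a minimal sketch of the two patterns the commit message describes: a pre-allocated device buffer fed through a pinned host staging tensor with a non-blocking copy (instead of creating new tensors and copying back and forth every step), and a host-side position_ids computation with a fast path for generate-only batches where every sequence length is 1. The class name PositionIdsSketch and its update() method are illustrative assumptions for this sketch, not the PR's actual SequenceInfo API.

import torch


class PositionIdsSketch:
    """Illustrative helper; mirrors the pattern in this commit, not the PR's SequenceInfo API."""

    def __init__(self, max_num_tokens: int, device: str = "cuda"):
        self.device = device
        # persistent device buffer, allocated once; later steps reuse it instead of
        # creating a fresh position_ids tensor on every forward pass
        self.position_ids_full = torch.zeros(max_num_tokens, dtype=torch.long, device=device)

    def update(self, input_pos: list, seq_lens: list) -> torch.Tensor:
        num_tokens = sum(seq_lens)
        if all(s == 1 for s in seq_lens):
            # generate-only fast path: one new token per sequence, so each token's
            # position is simply that sequence's current input position
            host = torch.tensor(input_pos, dtype=torch.long)
        else:
            # context/mixed path: expand each sequence's positions explicitly on the host
            positions = []
            for pos, length in zip(input_pos, seq_lens):
                positions.extend(range(pos, pos + length))
            host = torch.tensor(positions, dtype=torch.long)
        # pinned host staging lets the host-to-device copy run asynchronously
        if torch.cuda.is_available():
            host = host.pin_memory()
        self.position_ids_full[:num_tokens].copy_(host, non_blocking=True)
        # hand out a view of the persistent buffer; no reallocation per step
        return self.position_ids_full[:num_tokens]

For example, a generate-only step with four sequences, update([7, 3, 9, 12], [1, 1, 1, 1]), takes the fast path, while a context step such as update([0], [5]) falls back to the explicit expansion. In the diff below the same ideas appear as position_ids_full, seq_len_host/input_pos_host, and the torch.all(self.seq_len_host == 1) fast path inside _update_position_ids; the copy-avoidance side also shows up in update_input_ids_with_new_tokens, which keeps sampled tokens on the GPU and scatters them into the flattened input_ids instead of round-tripping through .cpu().tolist().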
@@ -162,7 +162,7 @@ class CapturedGraph(nn.Module):
         # copy inputs to input buffers
         for i, input_tensor in enumerate(args_batched):
-            self._input_buffers[i][: input_tensor.shape[0]] = input_tensor
+            self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True)

         # run forward pass via graph
         self.graphs[combined_shape].replay()
@@ -18,6 +18,8 @@ from torch._ops import OpOverloadPacket
 from torch.export import Dim
 from torch.fx import Node

+from tensorrt_llm._utils import nvtx_range
+

 @dataclass
 class CacheConfig:
@@ -87,11 +89,13 @@ class SequenceInfo:
     # Similarly, if a batch is composed of generate-only requests,
     # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens).
     max_num_tokens: Optional[int] = None
+    # device is the device on which the sequence info is stored.
+    device: str = "cuda"

     ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
     # input_ids MUST ALWAYS BE THE FIRST FIELD
-    input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int))
-    position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long))
+    input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
+    position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long))

     seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
     input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
@@ -110,24 +114,44 @@ class SequenceInfo:
         # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
         # (max_batch_size, max_seq_len) input in trtllm runtime.
         # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
-        max_seq_len_adjusted = self.max_seq_len + 1
+        self.max_seq_len_adjusted = self.max_seq_len + 1

         if self.max_num_tokens is None or self.max_num_tokens < 1:
-            self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
+            self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted
         # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
         # we use the provided max_num_tokens to calculate the number of pages
-        total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
+        total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted)
         # Num pages can not be less than max_batch_size.
         self._num_pages = max(
             self.max_batch_size,
             (total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
         )
-        self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
-        self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
-        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
-        self.input_pos = torch.empty_like(self.seq_len)
-        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
-        self.pages_per_seq = torch.empty_like(self.seq_len)
+        # Ensure that the device is set before initializing the tensors.
+        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
+        # Consumers of the sequence info args require input_ids and position_ids to be truncated.
+        # We maintain a full version of the input_ids and position_ids to avoid overheads of tensor
+        # creation in every forward pass.
+        self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids_full = torch.zeros(
+            self.max_num_tokens, dtype=torch.long, device=self.device
+        )
+
+        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device)
+        self.input_pos = torch.empty_like(self.seq_len, device=self.device)
+
+        # Allocated host tensors for sequence lengths and input positions so that
+        # position_ids calculation can be done on host.
+        self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int)
+        self.input_pos_host = torch.empty_like(self.seq_len_host)
+
+        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device)
+        self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device)
+
+        self.previous_batch_indices_cuda = torch.empty(
+            self.max_num_tokens, dtype=torch.long, device=self.device
+        )
         assert self.num_pages >= self.max_batch_size, (
             "num_pages must be greater than max_batch_size"
         )
@@ -140,13 +164,12 @@ class SequenceInfo:
         # indicator if extra args are activated that are needed for cached attention backends
         self._is_cached_attn = False

+        # total number of tokens in the current batch
+        self.num_tokens: int = 0
+
         # call reset once to initialize the tensors
         self.reset()

-    @property
-    def device(self) -> torch.device:
-        return self.input_pos.device
-
     @property
     def args(self) -> Tuple[torch.Tensor, ...]:
         args = []
@@ -156,11 +179,14 @@ class SequenceInfo:
             args.append(val)
             if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn:
                 break

         return tuple(args)

     @property
     def _num_uncached_attn_args(self) -> int:
-        """Return the number of original graph arguments expected by the model."""
+        """Return the number of original graph arguments expected by the model.
+
+        This is 2 because we have input_ids and position_ids as the original graph arguments.
+        """
         return 2

     @property
@@ -185,7 +211,7 @@ class SequenceInfo:
         dynamic_shapes = ({}, {})
         if self.max_batch_size > 1:
             dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
-        dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len)
+        dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted)
         # set up shape for position_ids (same as input_ids)
         dynamic_shapes[1].update(dynamic_shapes[0])
         # set up shape for extra args
@@ -204,7 +230,7 @@ class SequenceInfo:

     @property
     def input_positions(self) -> List[int]:
-        return self.input_pos[: self.num_sequences].tolist()
+        return self.input_pos_host[: self.num_sequences].tolist()

     @property
     def is_generate(self) -> bool:
@@ -334,14 +360,19 @@ class SequenceInfo:
         """
         # reset input_pos
-        self.input_pos.zero_()
+        self.input_pos_host.zero_()

         # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
-        self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))
+        self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)

         # reset cache information
         self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
         self.pages_per_seq.fill_(1)

+        # let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens)
+        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
     def set_example_sequence(self) -> None:
         """Set an example sequence useful for testing and export purposes."""
         self.reset()
@@ -352,7 +383,7 @@ class SequenceInfo:
             dtype=torch.int,
             device=self.device,
         )
-        self.nest_sequences(input_ids)
+        self.nest_sequences(input_ids, allow_realloc=True)

         # unflatten if we are not yet using cached+flattened attention
         if not self._is_cached_attn:
@@ -370,7 +401,7 @@ class SequenceInfo:
             device=self.device,
         )
         self.pages_per_seq.fill_(seq_len // self.page_size)
-        self.nest_sequences(input_ids)
+        self.nest_sequences(input_ids, allow_realloc=True)

     def set_generate_only_batch(self) -> None:
         """Set an example sequence for generate-only batch.
@@ -379,32 +410,96 @@ class SequenceInfo:
         mode. So we don't need to do anything mode-specific here.
         """
         self.reset()
-        self.nest_sequences([[1]] * self.max_batch_size)
-
-    def _update_position_ids(self) -> None:
-        # set new position_ids as new tensor from input_pos and seq_len via torch.arange
-        position_ids_list = [
-            num
-            for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
-            for num in range(in_pos, in_pos + seq_len)
-        ]
-        self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device)
+        self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
+
+    def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor:
         # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
         if self.is_generate:
-            self.position_ids = self.position_ids.view(-1, 1)
+            return tensor.view(-1, 1, *tensor.shape[1:])
         else:
-            self.position_ids = self.position_ids.view(1, -1)
+            return tensor.view(1, -1, *tensor.shape[1:])

-    def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
+    @nvtx_range("ad_update_position_ids")
+    def _update_position_ids(self, allow_realloc: bool = False) -> None:
+        # set new position_ids from input_pos and seq_len
+        # Make sure this is done on host to avoid host-device copies.
+        with nvtx_range("prepare_list"):
+            # Optimize for the common case where all seq_len values are 1 (generation mode)
+            if torch.all(self.seq_len_host == 1):
+                # Fast path: when all seq_len are 1, position_ids is just input_pos_host
+                position_ids_host = (
+                    self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory()
+                )
+            else:
+                # General case - can probably be optimized too, but overall impact will be minor.
+                position_ids_list = []
+                for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host):
+                    position_ids_list.extend(range(in_pos, in_pos + seq_len))
+                position_ids_host = torch.tensor(
+                    position_ids_list, dtype=torch.long, pin_memory=True
+                )
+        with nvtx_range("copy_to_device"):
+            if allow_realloc:
+                # Create a new position_ids tensor on the device
+                self.position_ids = position_ids_host.to(self.device).clone()
+            else:
+                self.position_ids_full = self.position_ids_full.flatten()
+                self.position_ids_full[: self.num_tokens].copy_(
+                    position_ids_host, non_blocking=True
+                )
+        with nvtx_range("maybe_reshape"):
+            self.position_ids = self.maybe_reshape_for_generate(
+                self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens]
+            )
+
+    @nvtx_range("ad_update_sequence_lengths")
+    def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None:
+        self._sequence_lengths = sequence_lengths
+        self.num_tokens = sum(self._sequence_lengths)
+        self.seq_len.zero_()
+        self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True)
+        self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True)
+
+    def update_input_ids_with_new_tokens(
+        self, new_tokens: torch.Tensor, previous_batch_indices: List[int]
+    ) -> None:
+        """Update the input_ids with new tokens.
+
+        This function will update the input_ids with new tokens and previous batch indices.
+        """
+        # 1) flatten once
+        original_shape = self.input_ids.shape
+        flat = self.input_ids.flatten()
+
+        # copy indices to the GPU
+        host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True)
+        idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
+        idx.copy_(host_idx, non_blocking=True)
+
+        # sort them so that masked_scatter_ lines up correctly
+        idx, _ = idx.sort()
+
+        # gather the exact values you want to write
+        src = new_tokens[0, idx, 0]
+
+        # in‐place fill every slot where flat == -1 with src, in order
+        flat.masked_scatter_(flat == -1, src)
+
+        # 4) reshape back
+        self.input_ids = flat.view(original_shape)
+
+    @nvtx_range("ad_nest_sequences")
+    def nest_sequences(
+        self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False
+    ) -> None:
         """Create and store a flattened list of input_ids from the provided list of sequences.

+        When allow_realloc is True, the input_ids will be reallocated on the device.
         This i/f will also update any relevant sequence information.
         """
         # set new sequence lengths
-        seq_lens = [len(ids) for ids in input_ids]
-        self.seq_len.zero_()
-        self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)
+        self._update_sequence_lengths([len(ids) for ids in input_ids])

         # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int
         dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int
         # set new input_ids as new tensor from flattened input_ids
@@ -413,49 +508,57 @@ class SequenceInfo:
             for lst in input_ids
             for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
         ]
-        self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device)
+        input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True)

-        # set derivative properties
-        self._sequence_lengths = seq_lens
-
-        # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
-        if self.is_generate:
-            self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:])
+        if allow_realloc:
+            self.input_ids = input_ids_host.to(self.device).clone()
         else:
-            self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:])
+            self.input_ids_full = self.input_ids_full.flatten()
+            self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True)

+        self.input_ids = self.maybe_reshape_for_generate(
+            self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens]
+        )
         # update position_ids
-        self._update_position_ids()
+        self._update_position_ids(allow_realloc=allow_realloc)

+    @nvtx_range("ad_unnest_sequences")
     def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
         return list(torch.split(t_squeezed, self.sequence_lengths))

+    @nvtx_range("ad_update_pos")
     def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
         """Update the starting position for each sequence in the cache.

         If ``reset=True`, ``input_pos`` will be reset to zero before updating.
         """
         if not isinstance(seq_len, torch.Tensor):
-            seq_len = torch.tensor(seq_len, dtype=torch.int)
+            seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True)
         bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size

         if reset:
-            self.input_pos[:bs] = seq_len.to(self.device)
+            self.input_pos_host[:bs].copy_(seq_len, non_blocking=True)
         else:
-            self.input_pos[:bs] += seq_len.to(self.device)
+            self.input_pos_host[:bs] += seq_len

-        # update position_ids
-        self._update_position_ids()
+        self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True)

+    @nvtx_range("ad_assign_cache_loc")
     def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
         """Set the cache location and pages_per_seq tensors from page assignments."""
         cache_loc_flat = torch.tensor(
-            [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
+            [p_idx for pages in page_assignments for p_idx in pages],
+            dtype=torch.int,
+            pin_memory=True,
         )
         self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)

-        pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
+        pages_per_seq = torch.tensor(
+            [len(p) for p in page_assignments], dtype=torch.int, pin_memory=True
+        )
         self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
@@ -94,20 +94,21 @@ class ADEngine(ModelEngine):
             f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}, {max_beam_width=}"
         )

+        # update device to contain the current default device if it's in cuda
+        device = torch.device(ad_config.device)
+        if device.type == "cuda" and device.index is None:
+            device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        device = str(device)
+
         # initialize seq info object
         seq_info = SequenceInfo(
             max_seq_len=max_seq_len,
             max_batch_size=max_batch_size,
             page_size=attn_page_size,
             max_num_tokens=max_num_tokens,
+            device=device,
         )

-        # update device to contain the current default device if it's in cuda
-        device = torch.device(ad_config.device)
-        if device.type == "cuda" and device.index is None:
-            device = torch.device(f"cuda:{torch.cuda.current_device()}")
-        device = str(device)
-
         # construct inference optimizer
         build_and_optimize = InferenceOptimizer(
             factory=ad_config.create_factory(), ad_config=ad_config
@@ -170,16 +171,12 @@ class ADEngine(ModelEngine):
         context_requests = scheduled_requests.context_requests
         gen_requests = [r for r in scheduled_requests.generation_requests if not r.draft_tokens]

-        # new_tokens is a tensor on the device, we need to convert it to a list of lists.
-        # can we avoid this additional gpu->cpu transfer?
-        new_tokens_list = new_tokens.flatten().cpu().tolist() if new_tokens is not None else None
-
         # info to be extracted
         input_ids: List[List[int]] = []
         input_pos: List[int] = []
         last_logit_only: List[bool] = []
         page_assignments: List[List[int]] = []

+        previous_batch_indices: List[int] = []
         # look at context requests first
         for request in context_requests:
             # store input ids and pos of first token in sequence
@@ -193,11 +190,13 @@ class ADEngine(ModelEngine):
         # TODO: we should also handle extend requests (for speculative decoding) here
         for request in gen_requests:
             # new_tokens are provided when the overlap scheduler is enabled.
-            if new_tokens_list is None or request.is_dummy or request.py_batch_idx is None:
+            if new_tokens is None or request.is_dummy or request.py_batch_idx is None:
                 input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)])
                 input_pos.append(request.max_beam_num_tokens - 1)
             else:
-                input_ids.append([new_tokens_list[request.py_batch_idx]])
+                # insert a dummy token to indicate the new tokens
+                input_ids.append([-1])
+                previous_batch_indices.append(request.py_batch_idx)
                 input_pos.append(request.max_beam_num_tokens)

             request.py_batch_idx = request.seq_slot
@@ -213,11 +212,15 @@ class ADEngine(ModelEngine):

         # update the sequence info object now
         si = self.cache_seq_interface.info
-        si.nest_sequences(input_ids)
         si.update_pos(input_pos, reset=True)
         si.assign_cache_loc(page_assignments)
+        si.nest_sequences(input_ids)
+
+        if new_tokens is not None:
+            si.update_input_ids_with_new_tokens(new_tokens, previous_batch_indices)
         return last_logit_only

+    @nvtx_range("ad_compute_logits")
     def _compute_logits(self) -> List[torch.Tensor]:
         # run the model
         logits: torch.Tensor = self.model(*self.cache_seq_interface.args)[0]
@@ -234,13 +237,13 @@ class ADEngine(ModelEngine):
         self,
         scheduled_requests: ScheduledRequests,
         resource_manager: ResourceManager,
-        new_tokens_device: Optional[torch.Tensor] = None,
+        new_tensors_device: Optional[torch.Tensor] = None,
         gather_context_logits: bool = False,
         cache_indirection_buffer: Optional[torch.Tensor] = None,
     ):
         """Run forward from scheduled requests; main entrypoint that gets called by the executor."""
         # convert requests and store in sequence info object
-        new_tokens = getattr(new_tokens_device, "new_tokens", None)
+        new_tokens = getattr(new_tensors_device, "new_tokens", None)
         last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)

         # compute all logits
@@ -187,7 +187,7 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config):
     # Helper function to call the model with proper sequence nesting
     def _call_and_unnest(x):
         # Use nest_sequences to properly set input_ids and automatically update position_ids
-        cm.info.nest_sequences(x)
+        cm.info.nest_sequences(x, allow_realloc=True)

         # Use the cm.args as is - it already contains the correct position_ids
         y = gm(*cm.args)