Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[fix] Fix add_dummy_requests for spec decoding cases (#4084)
* fix add_dummy_requests.
* add max_seq_len to eagle3 test and fix add_dummy_requests.
* fix prompt_len in add_dummy_requests.
* add prepare_resource condition in add_dummy_requests.
* add some description of token_nums to add_dummy_requests and fix token_nums in torch compile warmup.
* fix available_tokens.

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Parent: 5ce5b81281
Commit: 0cf0fce5d3
@@ -430,9 +430,6 @@ class PyTorchModelEngine(ModelEngine):
         def get_cuda_graph_warmup_request(batch_size):
             available_blocks = kv_cache_manager.get_num_free_blocks()
-            max_num_draft_tokens = self.spec_config.max_draft_tokens if self.is_spec_decode else 0
             if available_blocks >= batch_size:
                 result = ScheduledRequests()
                 result.context_requests = []
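The hunks below replace the per-call-site ternary with `self.max_draft_len`. A minimal sketch of the assumed equivalence follows; the diff itself does not show where `max_draft_len` is initialized, so the class and constructor here are illustrative only.

# Hypothetical sketch of the attribute that replaces the removed ternary.
# Assumption: self.max_draft_len is 0 whenever speculative decoding is off,
# so every add_dummy_requests() call can pass it unconditionally.
class _EngineSketch:
    def __init__(self, is_spec_decode: bool, max_draft_tokens: int):
        self.max_draft_len = max_draft_tokens if is_spec_decode else 0

assert _EngineSketch(False, 3).max_draft_len == 0
assert _EngineSketch(True, 3).max_draft_len == 3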
@@ -441,30 +438,19 @@ class PyTorchModelEngine(ModelEngine):
                 requests = kv_cache_manager.add_dummy_requests(
                     list(range(batch_size - 1)),
                     is_gen=True,
-                    max_num_draft_tokens=max_num_draft_tokens,
+                    max_num_draft_tokens=self.max_draft_len,
                 )
                 available_blocks -= batch_size - 1
-                available_tokens = available_blocks * kv_cache_manager.tokens_per_block
-                # When we generate last token for the max_seq_len case,
-                # we only need to store (max_seq_len - 1 - max_num_draft_tokens) tokens in the KV cache.
-                # For the max_seq_len, some speculative decoding methods need extra kv tokens in kv cache
-                # manager to support different kv lengths for the draft/target layers. So, we also
-                # need to remove those extra tokens from the max_seq_len.
-                token_num = max(
-                    1,
-                    min(
-                        available_tokens, self.max_seq_len -
-                        kv_cache_manager.num_extra_kv_tokens - 1 -
-                        max_num_draft_tokens),
-                )
+                available_tokens = kv_cache_manager.get_num_available_tokens(
+                    self.max_draft_len)
+
+                # Add one dummy request with the maximum possible sequence length.
+                # The sequence length is limited by both the max_seq_len and the number of available blocks.
+                token_num = max(1, min(available_tokens, self.max_seq_len - 1))
                 max_seq_len_request = kv_cache_manager.add_dummy_requests(
                     request_ids=[batch_size - 1],
                     token_nums=[token_num],
                     is_gen=True,
-                    max_num_draft_tokens=max_num_draft_tokens,
+                    max_num_draft_tokens=self.max_draft_len,
                 )[0]
                 # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
                 # This batch contains both the longest request and the shortest requests,
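The rewritten warmup path asks the KV cache manager for its available-token budget and then clamps the dummy request length against `self.max_seq_len - 1`. A standalone sketch of that clamp; all numbers are illustrative, not taken from a real configuration.

# Clamp used for the maximum-length warmup request; numbers illustrative.
available_tokens = 3196   # e.g. kv_cache_manager.get_num_available_tokens(self.max_draft_len)
max_seq_len = 2048        # self.max_seq_len

token_num = max(1, min(available_tokens, max_seq_len - 1))
assert token_num == 2047  # bounded by max_seq_len - 1, not by free KV capacity

# Even with no free capacity the clamp still yields a valid 1-token dummy request.
assert max(1, min(0, max_seq_len - 1)) == 1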
@@ -487,13 +473,12 @@ class PyTorchModelEngine(ModelEngine):
                     num_tokens_per_request / kv_cache_manager.tokens_per_block):
                 # Should only need (at most) one more page per request.
                 is_gen = num_tokens_per_request == 1
-                max_num_draft_tokens = self.spec_config.max_draft_tokens if self.is_spec_decode and is_gen else 0

                 requests = kv_cache_manager.add_dummy_requests(
                     list(range(batch_size)),
                     [num_tokens_per_request] * batch_size,
                     is_gen=is_gen,
-                    max_num_draft_tokens=max_num_draft_tokens)
+                    max_num_draft_tokens=self.max_draft_len)

                 if spec_resource_manager is not None:
                     spec_resource_manager.add_dummy_requests(
@@ -577,7 +562,7 @@ class PyTorchModelEngine(ModelEngine):
                 for num_tokens_per_request in [
                         1,
                         min(self.max_num_tokens // max(bs, 1),
-                            kv_cache_manager.max_seq_len - 1)
+                            self.max_seq_len - 1)
                 ]:
                     with release_batch(
                             get_torch_compile_warmup_request(
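This hunk caps the torch.compile warmup token count with `self.max_seq_len` instead of `kv_cache_manager.max_seq_len`, which, per the comment added to `add_dummy_requests` below, also budgets extra KV tokens and draft tokens. A hedged sketch of that assumed relationship; all values are illustrative.

# Assumed relationship, based on the comment added to add_dummy_requests():
# the manager's max_seq_len also budgets extra KV tokens and draft tokens,
# while the engine's max_seq_len is the model-level limit.
engine_max_seq_len = 2048       # self.max_seq_len
num_extra_kv_tokens = 1
max_draft_len = 3
manager_max_seq_len = engine_max_seq_len + num_extra_kv_tokens + max_draft_len

max_num_tokens, bs = 65536, 16  # illustrative engine limits
old_cap = min(max_num_tokens // max(bs, 1), manager_max_seq_len - 1)
new_cap = min(max_num_tokens // max(bs, 1), engine_max_seq_len - 1)
assert old_cap == 2051          # could request more than the model-level max_seq_len - 1
assert new_cap == 2047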
@@ -753,11 +738,10 @@ class PyTorchModelEngine(ModelEngine):
         # Set the dummy request ids starting at (uint64 max value - padding_size - 1) to avoid conflict with
         # active request IDs
         max_req_id = MAX_UINT64 - padding_size - 1
-        max_num_draft_tokens = self.spec_config.max_draft_tokens if self.is_spec_decode else 0
         generation_requests = kv_cache_manager.add_dummy_requests(
             [max_req_id + i + 1 for i in range(padding_size)],
             is_gen=True,
-            max_num_draft_tokens=max_num_draft_tokens)
+            max_num_draft_tokens=self.max_draft_len)
         scheduled_requests.generation_requests.extend(generation_requests)
         return generation_requests
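The padded dummy generation requests take their ids from the top of the uint64 range so they cannot collide with live request ids. A quick check of that id arithmetic; `MAX_UINT64` is spelled out locally and `padding_size` is illustrative.

# Sketch of the dummy request id allocation.
MAX_UINT64 = 2**64 - 1
padding_size = 4                # illustrative
max_req_id = MAX_UINT64 - padding_size - 1
dummy_ids = [max_req_id + i + 1 for i in range(padding_size)]

# Ids sit at the very top of the uint64 range and never reach MAX_UINT64 itself.
assert dummy_ids == [MAX_UINT64 - padding_size + i for i in range(padding_size)]
assert max(dummy_ids) == MAX_UINT64 - 1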
@@ -272,6 +272,12 @@ class KVCacheManager(BaseResourceManager):
     def add_dummy_requests(
         self,
         request_ids: List[int],
+        # Note that token_nums should be past_kv_len + input_len (without
+        # spec decoding). The draft tokens will be added in this function,
+        # so we don't need to take care of it in the caller. When preparing
+        # token_nums, we should not take the draft tokens into account, so
+        # don't use the kv_cache_manager.max_seq_len, which includes both
+        # extra tokens and draft tokens.
         token_nums: Optional[List[int]] = None,
         is_gen: bool = False,
         prepare_resource: bool = True,
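The new comment pins down the caller-side contract for `token_nums`: past KV length plus input length, with no allowance for draft tokens, since `add_dummy_requests` appends those itself. A hedged usage sketch of that contract; the per-request lengths are illustrative and the manager instance is assumed to already exist, so the call itself is left commented out.

# Hypothetical caller-side preparation of token_nums for add_dummy_requests().
past_kv_lens = [16, 128]   # KV history already stored for each dummy request (illustrative)
input_lens = [1, 1]        # one new token per generation step
max_draft_len = 3

# Per the new comment: token_nums excludes draft tokens; the manager adds them.
token_nums = [past + cur for past, cur in zip(past_kv_lens, input_lens)]
assert token_nums == [17, 129]

# requests = kv_cache_manager.add_dummy_requests(
#     request_ids=[0, 1],
#     token_nums=token_nums,
#     is_gen=True,
#     max_num_draft_tokens=max_draft_len,  # draft slots are handled internally
# )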
@@ -281,8 +287,7 @@ class KVCacheManager(BaseResourceManager):
         requests = []
         for i, req_id in enumerate(request_ids):
             sampling_params = SamplingParams()
-            token_num = token_nums[
-                i] if token_nums is not None else 1 + max_num_draft_tokens
+            token_num = token_nums[i] if token_nums is not None else 1
             encoder_input_tokens = [
                 1
             ] * token_num if self.impl.cross_kv else None
@@ -297,12 +302,17 @@ class KVCacheManager(BaseResourceManager):
             req.paged_kv_block_ids = []
             if prepare_resource:
                 self.impl.add_sequence(req_id, token_num, beam_width, req)
+                for _ in range(self.num_extra_kv_tokens):
+                    self.impl.add_token(req_id)
             if is_gen:
                 req.state = LlmRequestState.GENERATION_IN_PROGRESS
-                req.prompt_len = token_num - 1 + max_num_draft_tokens
+                req.prompt_len = token_num - 1
                 req.py_prompt_len = req.prompt_len
                 if max_num_draft_tokens > 0:
-                    req.py_draft_tokens = [0] * max_num_draft_tokens
+                    req.py_draft_tokens = [1] * max_num_draft_tokens
+                    if prepare_resource:
+                        for _ in range(max_num_draft_tokens):
+                            self.impl.add_token(req_id)
             requests.append(req)
         return requests
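Under the new logic a generation dummy request reserves `token_num` KV slots via `add_sequence`, then `num_extra_kv_tokens` and `max_num_draft_tokens` further slots via `add_token`, while `prompt_len` stays at `token_num - 1`. A small accounting sketch of that bookkeeping; numbers are illustrative and the manager calls are replaced by a plain counter.

# Accounting sketch for one generation dummy request under the new code path.
token_num = 2047
num_extra_kv_tokens = 1
max_num_draft_tokens = 3
prepare_resource = True

kv_slots = 0
if prepare_resource:
    kv_slots += token_num                # impl.add_sequence(req_id, token_num, ...)
    kv_slots += num_extra_kv_tokens      # one add_token() per extra KV token
prompt_len = token_num - 1               # draft tokens no longer folded in
py_draft_tokens = [1] * max_num_draft_tokens
if max_num_draft_tokens > 0 and prepare_resource:
    kv_slots += max_num_draft_tokens     # one add_token() per draft token

assert prompt_len == 2046
assert kv_slots == token_num + num_extra_kv_tokens + max_num_draft_tokens == 2051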
@@ -409,6 +419,10 @@ class KVCacheManager(BaseResourceManager):
     def get_num_kv_blocks(self, num_tokens: int) -> int:
         return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block

+    def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int:
+        return (self.get_num_free_blocks() * self.tokens_per_block -
+                self.num_extra_kv_tokens - max_num_draft_tokens)
+
     def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]:
         result = self.impl.get_primary_pool_data(layer_idx)
         return result.reshape(
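The new `get_num_available_tokens` helper centralizes the free-token computation that the warmup path previously did inline. A worked check of its formula; all values are illustrative.

# Worked check of the new helper's formula.
num_free_blocks = 100
tokens_per_block = 32
num_extra_kv_tokens = 1    # extra KV slots some spec-decode methods keep per sequence
max_num_draft_tokens = 3

available_tokens = (num_free_blocks * tokens_per_block -
                    num_extra_kv_tokens - max_num_draft_tokens)
assert available_tokens == 3196   # the value the warmup path clamps against max_seq_len - 1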