Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5441438][fix] Set correct draft length for the cuda graph dummy request (#6701)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Parent: ead89a0e40
Commit: b4fcd5f592
@@ -453,6 +453,10 @@ class PyTorchModelEngine(ModelEngine):
         else:
             self.cache_indirection_attention = None
 
+    @property
+    def runtime_draft_len(self):
+        return self.max_draft_len if self.enable_spec_decode else 0
+
     def set_lora_model_config(self, lora_target_modules: list[str],
                               trtllm_modules_to_hf_modules: dict[str, str]):
         self.lora_model_config = LoraModelConfig(
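The heart of the fix is the runtime_draft_len property added above: buffer sizes and dummy requests that previously used self.max_draft_len now use the runtime value, which drops to zero whenever speculative decoding is disabled. A minimal standalone sketch of the gating pattern (the class name here is hypothetical, not the real engine):

# Minimal sketch, not the actual PyTorchModelEngine: shows how
# runtime_draft_len collapses to 0 when speculative decoding is off.
class EngineSketch:
    def __init__(self, max_draft_len: int, enable_spec_decode: bool):
        self.max_draft_len = max_draft_len
        self.enable_spec_decode = enable_spec_decode

    @property
    def runtime_draft_len(self) -> int:
        return self.max_draft_len if self.enable_spec_decode else 0

assert EngineSketch(4, True).runtime_draft_len == 4
assert EngineSketch(4, False).runtime_draft_len == 0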
@@ -573,7 +577,7 @@ class PyTorchModelEngine(ModelEngine):
                     list(range(batch_size)), [num_tokens_per_request] *
                     batch_size if not is_gen else None,
                     is_gen=is_gen,
-                    max_num_draft_tokens=self.max_draft_len)
+                    max_num_draft_tokens=self.runtime_draft_len)
 
                 if spec_resource_manager is not None:
                     spec_resource_manager.add_dummy_requests(
@@ -592,7 +596,7 @@ class PyTorchModelEngine(ModelEngine):
 
         def get_autotune_warmup_request():
            available_tokens = kv_cache_manager.get_num_available_tokens(
-                self.max_draft_len)
+                self.runtime_draft_len)
            num_tokens_per_request = min(
                min(available_tokens, self.max_seq_len - 1),
                self.max_num_tokens)
@@ -626,14 +630,14 @@ class PyTorchModelEngine(ModelEngine):
                request_ids=list(range(full_len_request_num)),
                token_nums=[num_tokens_per_request] * full_len_request_num,
                is_gen=False,
-                max_num_draft_tokens=self.max_draft_len)
+                max_num_draft_tokens=self.runtime_draft_len)
 
            if remaining_tokens > 0:
                final_request = kv_cache_manager.add_dummy_requests(
                    request_ids=[full_len_request_num],
                    token_nums=[remaining_tokens],
                    is_gen=False,
-                    max_num_draft_tokens=self.max_draft_len)
+                    max_num_draft_tokens=self.runtime_draft_len)
 
                requests += final_request
 
@@ -680,7 +684,7 @@ class PyTorchModelEngine(ModelEngine):
        # Disable cuda graph capture here so that we can properly capture it later
        with self.no_cuda_graph():
            available_tokens = kv_cache_manager.get_num_available_tokens(
-                self.max_draft_len)
+                self.runtime_draft_len)
            warmup_batch_size = [1, self.batch_size // 2]
            if self.batch_size < 2:
                warmup_batch_size = [1]
@@ -898,7 +902,7 @@ class PyTorchModelEngine(ModelEngine):
            self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
                cuda_graph_dummy_request_ids,
                is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.runtime_draft_len,
                use_mrope=self.use_mrope,
                max_beam_width=self.max_beam_width)[0]
            self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
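Per the commit title, the hunk above is the central fix: the CUDA graph dummy request is now built with the runtime draft length. Presumably (an assumption, not stated in the diff) this keeps the token count captured into the graph consistent with what actually runs when speculative decoding is disabled. A toy illustration of the shape arithmetic, borrowing the 1 + draft_len convention that appears in the later sequence_lengths hunk:

# Toy arithmetic only; the 1 + draft_len convention is taken from the
# sequence_lengths.append(1 + self.runtime_draft_len) line further down.
def tokens_per_generation_request(draft_len: int) -> int:
    return 1 + draft_len

max_draft_len = 4
captured_with_max = tokens_per_generation_request(max_draft_len)  # 5 tokens
needed_without_spec = tokens_per_generation_request(0)            # 1 token
assert captured_with_max != needed_without_spec  # the mismatch being avoided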
@@ -1332,7 +1336,7 @@ class PyTorchModelEngine(ModelEngine):
                gather_ids.extend(
                    list(
                        range(len(position_ids),
-                              len(position_ids) + 1 + self.max_draft_len)))
+                              len(position_ids) + 1 + self.runtime_draft_len)))
                position_ids.extend(
                    list(
                        range(past_seen_token_num,
@@ -1348,23 +1352,23 @@ class PyTorchModelEngine(ModelEngine):
                # inputs
                # overlap scheduler can only support the speculative decoding
                # methods with a fixed number of draft tokens
-                sequence_lengths.append(1 + self.max_draft_len)
+                sequence_lengths.append(1 + self.runtime_draft_len)
                past_seen_token_num = request.max_beam_num_tokens - 1
-                draft_lens.append(self.max_draft_len)
+                draft_lens.append(self.runtime_draft_len)
                gather_ids.extend(
                    list(
                        range(len(position_ids),
-                              len(position_ids) + 1 + self.max_draft_len)))
+                              len(position_ids) + 1 + self.runtime_draft_len)))
                position_ids.extend(
                    list(
-                        range(past_seen_token_num,
-                              past_seen_token_num + 1 + self.max_draft_len)))
+                        range(past_seen_token_num, past_seen_token_num + 1 +
+                              self.runtime_draft_len)))
                # previous tensor
                previous_batch_indices.append(previous_batch_idx)
                previous_pos_indices.extend([previous_batch_idx] *
-                                            (1 + self.max_draft_len))
+                                            (1 + self.runtime_draft_len))
                num_cached_tokens_per_seq.append(past_seen_token_num +
-                                                 self.max_draft_len + 1)
+                                                 self.runtime_draft_len + 1)
                prompt_lengths.append(request.py_prompt_len)
                request_ids.append(request.py_request_id)
 
@@ -1441,21 +1445,21 @@ class PyTorchModelEngine(ModelEngine):
                previous_slots = previous_seq_slots_device()
                # previous input ids
                previous_batch_tokens = previous_batch_len * (
-                    1 + self.max_draft_len)
+                    1 + self.runtime_draft_len)
                new_tokens = new_tokens_device.transpose(
                    0, 1)[previous_slots, :].flatten()
                self.input_ids_cuda[num_tokens:num_tokens +
                                    previous_batch_tokens].copy_(
                                        new_tokens, non_blocking=True)
                # previous draft tokens
-                previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
+                previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
                self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
                                       previous_batch_draft_tokens].copy_(
                                           next_draft_tokens_device[
                                               previous_slots, :].flatten(),
                                           non_blocking=True)
                # prepare data for the preprocess inputs
-                kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
+                kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
                previous_pos_indices_host = torch.tensor(previous_pos_indices,
                                                         dtype=torch.int,
                                                         pin_memory=True)
@@ -1480,8 +1484,8 @@ class PyTorchModelEngine(ModelEngine):
                    extend_dummy_requests)
                self.previous_pos_id_offsets_cuda[
                    (num_extend_reqeust_wo_dummy - previous_batch_len) *
-                    (1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
-                    (1 + self.max_draft_len)].copy_(
+                    (1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
+                    (1 + self.runtime_draft_len)].copy_(
                        new_tokens_lens_device[self.previous_pos_indices_cuda[
                            0:previous_batch_tokens]],
                        non_blocking=True)
@@ -336,7 +336,7 @@ class TorchSampler(Sampler):
        if request.py_draft_logits is None:
            new_token = add_token(request, new_tokens, beam=self.BEAM)
            stop = self._handle_stop_criteria(request, new_token)
-            if stop or len(request.py_draft_tokens) == 0:
+            if stop or get_draft_token_length(request) == 0:
                return 0
            num_accepted = 0
 
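The sampler hunks replace direct len(request.py_draft_tokens) calls with a get_draft_token_length helper whose body is not part of this diff. A plausible reading, stated here as an assumption, is that it treats an absent py_draft_tokens as zero draft tokens:

# Assumed behavior of get_draft_token_length; the real helper in
# TensorRT-LLM may differ. Shown only to make the intent of the hunk clear.
def get_draft_token_length(request) -> int:
    draft_tokens = getattr(request, "py_draft_tokens", None)
    return 0 if draft_tokens is None else len(draft_tokens)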
@@ -360,10 +360,10 @@ class TorchSampler(Sampler):
                request.py_draft_logits[0],
                generator=generator)
            target_probs = request.py_target_probs
-            p = draft_probs[torch.arange(len(request.py_draft_tokens)),
+            p = draft_probs[torch.arange(get_draft_token_length(request)),
                            request.py_draft_tokens]
            q = target_probs[:-1]
-            q = q[torch.arange(len(request.py_draft_tokens)),
+            q = q[torch.arange(get_draft_token_length(request)),
                  request.py_draft_tokens]
            accept_probs = torch.minimum(torch.ones(()), q / p)
            # Use deterministic random generation for multi-GPU consistency
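The hunk above is the draft-token rejection-sampling step: p gathers the draft model's probability of each drafted token, q the target model's, and each draft token is accepted with probability min(1, q/p). A self-contained sketch with made-up toy tensors (vocabulary of 4, two draft tokens), independent of the request objects used in the real code:

import torch

# Toy tensors only; shapes mirror the hunk above. target_probs has one extra
# row, which the [:-1] slice drops before gathering per-token probabilities.
draft_probs = torch.tensor([[0.1, 0.6, 0.2, 0.1],
                            [0.3, 0.3, 0.2, 0.2]])
target_probs = torch.tensor([[0.2, 0.5, 0.2, 0.1],
                             [0.1, 0.6, 0.2, 0.1],
                             [0.25, 0.25, 0.25, 0.25]])
draft_tokens = torch.tensor([1, 0])

p = draft_probs[torch.arange(len(draft_tokens)), draft_tokens]         # [0.6, 0.3]
q = target_probs[:-1][torch.arange(len(draft_tokens)), draft_tokens]   # [0.5, 0.1]
accept_probs = torch.minimum(torch.ones(()), q / p)
print(accept_probs)  # tensor([0.8333, 0.3333])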
@@ -374,7 +374,7 @@ class TorchSampler(Sampler):
            sample_last = True
            stop = False
            if rejected_indices.numel() == 0:
-                num_initially_accepted = len(request.py_draft_tokens)
+                num_initially_accepted = get_draft_token_length(request)
                sample_last = False
            else:
                num_initially_accepted = rejected_indices[0].item()
@@ -575,7 +575,7 @@ class TorchSampler(Sampler):
        logits = raw_logits[:sum_steps]
        # Collect steps per request for batched strategy
        steps_per_request = [
-            1 + len(req.py_draft_tokens) for req in requests
+            1 + get_draft_token_length(req) for req in requests
        ]
        logits = self._apply_embedding_bias(logits, requests,
                                            steps_per_request)
@@ -155,6 +155,8 @@ class AccuracyTask:
            spec_dec_algo = None
        elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
            spec_dec_algo = llm.args.speculative_config.decoding_type
+            if spec_dec_algo == 'AUTO':
+                spec_dec_algo = 'NGram'
        else:
            raise ValueError(
                f"Not recognized speculative_config: {llm.args.speculative_config}."
@@ -21,10 +21,10 @@ from tensorrt_llm import LLM
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
-from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
-                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
-                                 NGramDecodingConfig, SamplingParams,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
+                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
+                                 MTPDecodingConfig, NGramDecodingConfig,
+                                 SamplingParams, TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,
@@ -355,6 +355,23 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
            task = JsonModeEval(self.MODEL_NAME)
            task.evaluate(llm)
 
+    @skip_pre_hopper
+    def test_auto_spec_decode(self):
+        pytorch_config = {
+            "cuda_graph_config":
+            CudaGraphConfig(batch_sizes=[1, 32, 64], enable_padding=True)
+        }
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        free_gpu_memory_fraction=0.5)
+        spec_config = AutoDecodingConfig()
+        with LLM(model=self.MODEL_PATH,
+                 **pytorch_config,
+                 kv_cache_config=kv_cache_config,
+                 speculative_config=spec_config,
+                 max_batch_size=64) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.2-1B"