[https://nvbugs/5441438][fix] Set correct draft length for the cuda graph dummy request (#6701)

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
This commit is contained in:
Ziyi Xiong 2025-08-12 09:28:47 +08:00 committed by GitHub
parent ead89a0e40
commit b4fcd5f592
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 28 deletions

View File

@ -453,6 +453,10 @@ class PyTorchModelEngine(ModelEngine):
else:
self.cache_indirection_attention = None
@property
def runtime_draft_len(self):
return self.max_draft_len if self.enable_spec_decode else 0
def set_lora_model_config(self, lora_target_modules: list[str],
trtllm_modules_to_hf_modules: dict[str, str]):
self.lora_model_config = LoraModelConfig(
@ -573,7 +577,7 @@ class PyTorchModelEngine(ModelEngine):
list(range(batch_size)), [num_tokens_per_request] *
batch_size if not is_gen else None,
is_gen=is_gen,
max_num_draft_tokens=self.max_draft_len)
max_num_draft_tokens=self.runtime_draft_len)
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(
@ -592,7 +596,7 @@ class PyTorchModelEngine(ModelEngine):
def get_autotune_warmup_request():
available_tokens = kv_cache_manager.get_num_available_tokens(
self.max_draft_len)
self.runtime_draft_len)
num_tokens_per_request = min(
min(available_tokens, self.max_seq_len - 1),
self.max_num_tokens)
@ -626,14 +630,14 @@ class PyTorchModelEngine(ModelEngine):
request_ids=list(range(full_len_request_num)),
token_nums=[num_tokens_per_request] * full_len_request_num,
is_gen=False,
max_num_draft_tokens=self.max_draft_len)
max_num_draft_tokens=self.runtime_draft_len)
if remaining_tokens > 0:
final_request = kv_cache_manager.add_dummy_requests(
request_ids=[full_len_request_num],
token_nums=[remaining_tokens],
is_gen=False,
max_num_draft_tokens=self.max_draft_len)
max_num_draft_tokens=self.runtime_draft_len)
requests += final_request
@ -680,7 +684,7 @@ class PyTorchModelEngine(ModelEngine):
# Disable cuda graph capture here so that we can properly capture it later
with self.no_cuda_graph():
available_tokens = kv_cache_manager.get_num_available_tokens(
self.max_draft_len)
self.runtime_draft_len)
warmup_batch_size = [1, self.batch_size // 2]
if self.batch_size < 2:
warmup_batch_size = [1]
@ -898,7 +902,7 @@ class PyTorchModelEngine(ModelEngine):
self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
cuda_graph_dummy_request_ids,
is_gen=True,
max_num_draft_tokens=self.max_draft_len,
max_num_draft_tokens=self.runtime_draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width)[0]
self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
@ -1332,7 +1336,7 @@ class PyTorchModelEngine(ModelEngine):
gather_ids.extend(
list(
range(len(position_ids),
len(position_ids) + 1 + self.max_draft_len)))
len(position_ids) + 1 + self.runtime_draft_len)))
position_ids.extend(
list(
range(past_seen_token_num,
@ -1348,23 +1352,23 @@ class PyTorchModelEngine(ModelEngine):
# inputs
# overlap scheduler can only support the speculative decoding
# methods with a fixed number of draft tokens
sequence_lengths.append(1 + self.max_draft_len)
sequence_lengths.append(1 + self.runtime_draft_len)
past_seen_token_num = request.max_beam_num_tokens - 1
draft_lens.append(self.max_draft_len)
draft_lens.append(self.runtime_draft_len)
gather_ids.extend(
list(
range(len(position_ids),
len(position_ids) + 1 + self.max_draft_len)))
len(position_ids) + 1 + self.runtime_draft_len)))
position_ids.extend(
list(
range(past_seen_token_num,
past_seen_token_num + 1 + self.max_draft_len)))
range(past_seen_token_num, past_seen_token_num + 1 +
self.runtime_draft_len)))
# previous tensor
previous_batch_indices.append(previous_batch_idx)
previous_pos_indices.extend([previous_batch_idx] *
(1 + self.max_draft_len))
(1 + self.runtime_draft_len))
num_cached_tokens_per_seq.append(past_seen_token_num +
self.max_draft_len + 1)
self.runtime_draft_len + 1)
prompt_lengths.append(request.py_prompt_len)
request_ids.append(request.py_request_id)
@ -1441,21 +1445,21 @@ class PyTorchModelEngine(ModelEngine):
previous_slots = previous_seq_slots_device()
# previous input ids
previous_batch_tokens = previous_batch_len * (
1 + self.max_draft_len)
1 + self.runtime_draft_len)
new_tokens = new_tokens_device.transpose(
0, 1)[previous_slots, :].flatten()
self.input_ids_cuda[num_tokens:num_tokens +
previous_batch_tokens].copy_(
new_tokens, non_blocking=True)
# previous draft tokens
previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
previous_batch_draft_tokens].copy_(
next_draft_tokens_device[
previous_slots, :].flatten(),
non_blocking=True)
# prepare data for the preprocess inputs
kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
previous_pos_indices_host = torch.tensor(previous_pos_indices,
dtype=torch.int,
pin_memory=True)
@ -1480,8 +1484,8 @@ class PyTorchModelEngine(ModelEngine):
extend_dummy_requests)
self.previous_pos_id_offsets_cuda[
(num_extend_reqeust_wo_dummy - previous_batch_len) *
(1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
(1 + self.max_draft_len)].copy_(
(1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
(1 + self.runtime_draft_len)].copy_(
new_tokens_lens_device[self.previous_pos_indices_cuda[
0:previous_batch_tokens]],
non_blocking=True)

View File

@ -336,7 +336,7 @@ class TorchSampler(Sampler):
if request.py_draft_logits is None:
new_token = add_token(request, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
if stop or len(request.py_draft_tokens) == 0:
if stop or get_draft_token_length(request) == 0:
return 0
num_accepted = 0
@ -360,10 +360,10 @@ class TorchSampler(Sampler):
request.py_draft_logits[0],
generator=generator)
target_probs = request.py_target_probs
p = draft_probs[torch.arange(len(request.py_draft_tokens)),
p = draft_probs[torch.arange(get_draft_token_length(request)),
request.py_draft_tokens]
q = target_probs[:-1]
q = q[torch.arange(len(request.py_draft_tokens)),
q = q[torch.arange(get_draft_token_length(request)),
request.py_draft_tokens]
accept_probs = torch.minimum(torch.ones(()), q / p)
# Use deterministic random generation for multi-GPU consistency
@ -374,7 +374,7 @@ class TorchSampler(Sampler):
sample_last = True
stop = False
if rejected_indices.numel() == 0:
num_initially_accepted = len(request.py_draft_tokens)
num_initially_accepted = get_draft_token_length(request)
sample_last = False
else:
num_initially_accepted = rejected_indices[0].item()
@ -575,7 +575,7 @@ class TorchSampler(Sampler):
logits = raw_logits[:sum_steps]
# Collect steps per request for batched strategy
steps_per_request = [
1 + len(req.py_draft_tokens) for req in requests
1 + get_draft_token_length(req) for req in requests
]
logits = self._apply_embedding_bias(logits, requests,
steps_per_request)

View File

@ -155,6 +155,8 @@ class AccuracyTask:
spec_dec_algo = None
elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
spec_dec_algo = llm.args.speculative_config.decoding_type
if spec_dec_algo == 'AUTO':
spec_dec_algo = 'NGram'
else:
raise ValueError(
f"Not recognized speculative_config: {llm.args.speculative_config}."

View File

@ -21,10 +21,10 @@ from tensorrt_llm import LLM
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, SamplingParams,
TorchCompileConfig)
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
SamplingParams, TorchCompileConfig)
from tensorrt_llm.quantization import QuantAlgo
from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,
@ -355,6 +355,23 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)
@skip_pre_hopper
def test_auto_spec_decode(self):
pytorch_config = {
"cuda_graph_config":
CudaGraphConfig(batch_sizes=[1, 32, 64], enable_padding=True)
}
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
free_gpu_memory_fraction=0.5)
spec_config = AutoDecodingConfig()
with LLM(model=self.MODEL_PATH,
**pytorch_config,
kv_cache_config=kv_cache_config,
speculative_config=spec_config,
max_batch_size=64) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.2-1B"