Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[nvbug/5319281][fix] Stop drafting when we hit the draft model's max seq len (#4879)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
commit 25aa3881d7
parent 3d87770e15
@@ -1080,13 +1080,22 @@ class PyTorchModelEngine(ModelEngine):
         return model
 
     def _init_max_seq_len(self):
+        inferred_max_seq_len = self.model.infer_max_seq_len()
         if self.max_seq_len is None:
-            inferred_max_seq_len = self.model.infer_max_seq_len()
             logger.info(
                 f"max_seq_len is not specified, using inferred value {inferred_max_seq_len}"
             )
             self.max_seq_len = inferred_max_seq_len
+
+        elif inferred_max_seq_len < self.max_seq_len:
+            # NOTE: py_executor_creator makes sure that the executor uses this
+            # smaller value as its max_seq_len too.
+            logger.warning(
+                f"Specified {self.max_seq_len=} is larger than what the model can support "
+                f"({inferred_max_seq_len}). Setting max_seq_len to {inferred_max_seq_len}. "
+            )
+            self.max_seq_len = inferred_max_seq_len
 
     def _init_max_num_tokens(self):
        # Modified from tensorrt_llm/_common.py check_max_num_tokens
        if self.max_num_tokens is None:
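For reference, a minimal standalone sketch of the resolution logic in the hunk above. The `infer_max_seq_len` behavior and the clamping rule follow the diff; the free function, names, and example numbers are simplified and hypothetical, not the real TensorRT-LLM API.

from typing import Optional

def resolve_max_seq_len(requested: Optional[int], inferred: int) -> int:
    """Pick the max_seq_len an engine should run with.

    - No user value: fall back to what the model can support.
    - User value above the model's limit: clamp to the inferred value
      (the executor is expected to adopt the smaller value as well).
    - Otherwise: honor the user's value.
    """
    if requested is None:
        return inferred
    if inferred < requested:
        return inferred
    return requested

# Example: a model that only supports 2048 tokens clamps a larger request.
assert resolve_max_seq_len(None, 2048) == 2048
assert resolve_max_seq_len(8192, 2048) == 2048
assert resolve_max_seq_len(1024, 2048) == 1024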
@@ -1757,6 +1757,16 @@ class PyExecutor:
                 # No space for draft tokens.
                 continue
 
+            # Stop drafting when we hit the max seqlen. We still need dummy draft
+            # tokens attached to the requests to make sure everything works properly
+            # with CUDA graph. These dummy tokens are already added by
+            # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
+            # that we want to do spec decoding, so no need to do anything else here.
+            # This makes the perf for this case suboptimal, but that's OK - this is
+            # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
+            if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
+                continue
+
             num_draft_tokens = len(
                 request.py_last_draft_tokens
             ) if request.py_last_draft_tokens is not None else 0
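A minimal sketch of the early-exit condition introduced above, assuming `max_beam_num_tokens` is the current token count of the request's longest beam. The toy request class and the 2048 threshold are illustrative only.

class ToyRequest:
    def __init__(self, max_beam_num_tokens: int):
        self.max_beam_num_tokens = max_beam_num_tokens

def should_skip_drafting(request: ToyRequest, draft_max_seq_len: int) -> bool:
    # Once the sequence has reached the draft model's limit, running the draft
    # model again would exceed its max_seq_len, so drafting is skipped and the
    # request keeps only its dummy draft tokens.
    return request.max_beam_num_tokens - 1 >= draft_max_seq_len

assert should_skip_drafting(ToyRequest(2049), draft_max_seq_len=2048)
assert not should_skip_drafting(ToyRequest(2048), draft_max_seq_len=2048)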
@@ -231,7 +231,9 @@ def create_py_executor(executor_config: ExecutorConfig,
         pytorch_backend_config,
         batch_size=executor_config.max_batch_size,
         max_num_tokens=executor_config.max_num_tokens,
-        max_seq_len=model_engine.max_seq_len,
+        # Note: The draft model engine will infer its own max_seq_len.
+        # We'll stop drafting when we hit the max.
+        max_seq_len=executor_config.max_seq_len,
         mapping=mapping,
         attn_runtime_features=attn_runtime_features,
         dist=dist,
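A self-contained sketch (toy dataclasses, not the real TensorRT-LLM API) of the configuration split this hunk introduces: the executor now takes the user-facing max_seq_len from the executor config, while the draft engine keeps the smaller limit it inferred from its own model config. The 8192 and 2048 values are assumed, illustrative numbers.

from dataclasses import dataclass

@dataclass
class ToyExecutorConfig:
    max_seq_len: int  # user-specified, target-model limit

@dataclass
class ToyDraftEngine:
    max_seq_len: int  # inferred from the draft model's own config

executor_config = ToyExecutorConfig(max_seq_len=8192)
draft_engine = ToyDraftEngine(max_seq_len=2048)

# The executor schedules sequences up to the user-facing limit; once a request
# crosses the draft engine's smaller limit, drafting simply stops (previous hunk)
# instead of forcing the two limits to match.
assert draft_engine.max_seq_len < executor_config.max_seq_len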
@@ -46,11 +46,13 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
     llm_spec = LLM(
         model=target_model_dir,
         **pytorch_config,
+        # This max_seq_len is larger than the one specified
+        # in the llama 3 8B eagle's config. We want to make sure
+        # that the draft model won't go above its max in warmup
+        # in this test.
+        max_seq_len=8192,
         kv_cache_config=kv_cache_config,
-        speculative_config=spec_config,
-        # TODO: https://nvbugspro.nvidia.com/bug/5319281
-        max_num_tokens=2048,
-        max_seq_len=2048)
+        speculative_config=spec_config)
 
     sampling_params = SamplingParams(
         max_tokens=32,
@@ -88,9 +90,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
 
     llm_ref = LLM(model=target_model_dir,
                   **pytorch_config,
-                  kv_cache_config=kv_cache_config,
-                  max_num_tokens=2048,
-                  max_seq_len=2048)
+                  kv_cache_config=kv_cache_config)
 
     results_ref = llm_ref.generate(prompts, sampling_params)
     generated_text_ref = [result.outputs[0].text for result in results_ref]
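To tie the test configuration back to the new guard: with the target LLM allowed 8192 tokens and a draft model whose own config caps out lower, generation past the draft limit should continue without drafting rather than erroring. A toy simulation in plain Python (no TensorRT-LLM calls; the 2048 draft limit is an assumed, illustrative value):

def drafted_steps(prompt_len: int, new_tokens: int, draft_max_seq_len: int) -> int:
    """Count how many decode steps still get draft tokens."""
    drafted = 0
    for step in range(new_tokens):
        current_len = prompt_len + step          # tokens in the beam so far
        if current_len - 1 < draft_max_seq_len:  # mirrors the guard above
            drafted += 1                         # draft model still proposes tokens
        # else: this step decodes normally, without speculation
    return drafted

# A long sequence keeps decoding, but drafting stops once it reaches the
# (assumed) draft limit of 2048, well before the target limit of 8192.
assert drafted_steps(prompt_len=2040, new_tokens=32, draft_max_seq_len=2048) < 32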