Reintroduce with perf fixes: feature: unify new_tokens format sample state to trtllm sampler tokens format (#5513)

58a8a8f - these changes were previously merged to main in this commit.
6aef149 - the changes were temporarily reverted on main due to a significant perf regression in models using the TorchSampler (observed by @byshiue).
This PR re-merges those changes along with a fix that prevents the regression.

The first commit of this PR is just the revert of that revert - filter it out to see only the previously unmerged changes.

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
Authored by Netanel Haber on 2025-06-30 21:58:59 +03:00, committed by GitHub
parent f28cd3056e
commit 6ee94c7ac8
12 changed files with 461 additions and 498 deletions
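
The unifying idea behind the diff below: TorchSampler now keeps a preallocated new_tokens buffer in the layout the TRTLLM sampler already exposes through cpp DecoderState.getAllNewTokens(), shape (max_draft_tokens + 1, max_num_sequences, beam_width), indexed by decode step, sequence slot and beam. A minimal standalone sketch of that layout (sizes are illustrative; read_new_token is a hypothetical helper mirroring the add_token helper added in sampler.py):

import torch

# Illustrative sizes, not values taken from this PR.
max_draft_tokens, max_num_sequences, beam_width = 3, 8, 1

# Unified layout: new_tokens[step, seq_slot, beam]
new_tokens = torch.zeros(
    (max_draft_tokens + 1, max_num_sequences, beam_width), dtype=torch.int)

def read_new_token(new_tokens: torch.Tensor, seq_slot: int, *,
                   beam: int = 0, step: int = 0) -> int:
    # Hypothetical mirror of add_token(): one token per (step, seq_slot, beam).
    return int(new_tokens[step, seq_slot, beam])

# A request assigned sequence slot 5 reads its step-0 token from column 5.
token = read_new_token(new_tokens, seq_slot=5)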

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
import torch
from torch._prims_common import DeviceLikeType
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
from tensorrt_llm._utils import nvtx_range
from ...._utils import mpi_rank, mpi_world_size
@ -256,6 +257,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
ad_config: LlmArgs = executor_config.pytorch_backend_config
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
# some derivative properties
max_draft_tokens = (
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens
@ -272,7 +274,13 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
max_seq_len=ad_config.max_seq_len,
max_batch_size=ad_config.max_batch_size,
)
resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
seq_slot_manager = SeqSlotManager(max_num_sequences=max_num_sequences)
resource_manager = ResourceManager(
{
ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager,
ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager,
}
)
resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)
# scheduling
@ -287,10 +295,14 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
# https://github.com/NVIDIA/TensorRT-LLM/issues/5254
# We should expose mixed_sample to our build_and_run_ad script so we can configure this
# correctly for models as needed.
sampler = TorchSampler(
sampler_args = TorchSampler.Args(
max_seq_len=ad_config.max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=ad_config.mixed_sampler,
)
sampler = TorchSampler(sampler_args)
# creating the executor object
py_executor = PyExecutor(
@ -299,6 +311,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
model_engine=engine,
sampler=sampler,
dist=mpi_dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_input_len=ad_config.max_input_len,
max_batch_size=ad_config.max_batch_size,

View File

@ -26,8 +26,7 @@ from .py_executor import PyExecutor
from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
PeftCacheManager, ResourceManager,
ResourceManagerType)
from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
TRTLLMSampler)
from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
SimpleScheduler)
from .seq_slot_manager import SeqSlotManager
@ -514,6 +513,7 @@ def create_py_executor_instance(
sampler=sampler,
drafter=drafter,
dist=dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=pytorch_backend_config.
disable_overlap_scheduler,
max_batch_size=executor_config.max_batch_size,
@ -525,27 +525,44 @@ def create_py_executor_instance(
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
def instantiate_sampler(model_engine: PyTorchModelEngine,
def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
*, max_seq_len: int, mixed_sampler: bool):
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_draft_tokens = (0 if executor_config.speculative_config is None else
executor_config.speculative_config.max_draft_tokens)
return TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=mixed_sampler,
)
def instantiate_sampler(engine: PyTorchModelEngine,
executor_config: ExecutorConfig,
pytorch_backend_config: PyTorchConfig,
mapping: Mapping):
sampler_args = create_torch_sampler_args(
executor_config,
mapping,
max_seq_len=engine.max_seq_len,
mixed_sampler=pytorch_backend_config.mixed_sampler)
if mapping.cp_config.get('cp_type') == 'star_attention':
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
return TorchStarAttentionSampler(max_seq_len=model_engine.max_seq_len)
spec_config = model_engine.spec_config
if spec_config is not None and spec_config.spec_dec_mode.has_spec_decoder():
return get_spec_decoder(max_seq_len=model_engine.max_seq_len,
spec_config=spec_config)
return TorchSampler(sampler_args)
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
):
return get_spec_decoder(sampler_args, engine.spec_config)
if pytorch_backend_config.enable_trtllm_sampler:
return TRTLLMSampler(executor_config, model_engine.model,
model_engine.dtype, mapping,
get_decoding_mode(executor_config),
decoding_mode = get_decoding_mode(executor_config)
return TRTLLMSampler(executor_config, engine.model, engine.dtype,
mapping, decoding_mode,
pytorch_backend_config.disable_overlap_scheduler)
elif not model_engine.model.model_config.is_generation:
if not engine.model.model_config.is_generation:
# NOTE: choose sampler based on model type
return EarlyStopSampler()
return TorchSampler(max_seq_len=model_engine.max_seq_len,
mixed_sampler=pytorch_backend_config.mixed_sampler)
return TorchSampler(sampler_args)
def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:

View File

@ -1,4 +1,3 @@
import itertools
import math
from typing import List, Optional
@ -52,8 +51,7 @@ class GuidedDecoder:
def build(self, scheduled_requests: ScheduledRequests,
resource_manager: SeqSlotManager) -> None:
for llm_req in itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
for llm_req in scheduled_requests.all_requests():
if llm_req.guided_decoding_params is None:
continue
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
@ -84,9 +82,7 @@ class GuidedDecoder:
torch.cuda.current_stream().wait_stream(self._stream)
batched_logits, batched_bitmask = [], []
for i, llm_req in enumerate(
itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)):
for i, llm_req in enumerate(scheduled_requests.all_requests()):
if llm_req.guided_decoding_params is None:
continue
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:

View File

@ -254,6 +254,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
exclude_last_generation_logits: bool = False,
return_perf_metrics: bool = False,
stop_words_list: list[list[int]] | None = None,
is_draft: bool = False,
**kwargs):
self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
None)
@ -288,6 +289,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_return_context_logits = return_context_logits
self.py_return_generation_logits = return_generation_logits
self.py_return_logits_device_memory = return_logits_device_memory
self.py_is_draft = is_draft
# TODO: remove this when use DynamicDecodeOp in pytorch flow.
# currently, keep py_stop_words_list as python list, rather than tensor.

View File

@ -4,7 +4,6 @@ import functools
import gc
import glob
import inspect
import itertools
import math
import multiprocessing
import os
@ -21,6 +20,7 @@ import torch
import torch._dynamo.config
import tensorrt_llm.bindings.internal.userbuffers as ub
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
@ -319,6 +319,7 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
class PyTorchModelEngine(ModelEngine):
BEAM_WIDTH = 1
def __init__(
self,
@ -659,13 +660,12 @@ class PyTorchModelEngine(ModelEngine):
return result
@contextlib.contextmanager
def release_batch(result):
def release_batch(result: ScheduledRequests | None):
try:
yield result
finally:
if result is not None:
for req in itertools.chain(result.generation_requests,
result.context_requests):
for req in result.all_requests():
kv_cache_manager.free_resources(req)
if spec_resource_manager is not None:
spec_resource_manager.free_resources(req)
@ -1153,7 +1153,15 @@ class PyTorchModelEngine(ModelEngine):
draft_lens = []
mrope_config = defaultdict(list)
batch_idx = 0
mtp_batch_idx = 0 # Temporary: MTP (and Eagle3OneModel) remain the only samplers to index new_tokens serially
def py_batch_idx(request: LlmRequest) -> int:
if not self.without_logits:
return request.seq_slot
nonlocal mtp_batch_idx
batch_idx = mtp_batch_idx
mtp_batch_idx += 1
return batch_idx
for request in scheduled_requests.context_requests:
request_ids.append(request.py_request_id)
@ -1184,10 +1192,9 @@ class PyTorchModelEngine(ModelEngine):
) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin
mrope_config['mrope_rotary_cos_sin'].append(
mrope_rotary_cos_sin.to('cuda', non_blocking=True))
request.py_batch_idx = batch_idx
batch_idx += 1
request.py_batch_idx = py_batch_idx(request)
num_ctx_requests = batch_idx
num_ctx_requests = len(scheduled_requests.context_requests)
num_ctx_tokens = len(input_ids)
new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None
if new_tensors_device is not None:
@ -1227,7 +1234,7 @@ class PyTorchModelEngine(ModelEngine):
assert spec_dec_mode.support_overlap_scheduler(
), f"{self.spec_config.spec_dec_name} does not support overlap scheduler"
# will contain previous batch incices of generation requests
# will contain previous batch indices of generation requests
previous_batch_indices = []
previous_pos_indices = []
for request in extend_requests:
@ -1267,13 +1274,11 @@ class PyTorchModelEngine(ModelEngine):
num_cached_tokens_per_seq.append(past_seen_token_num)
request_ids.append(request.py_request_id)
# update batch index
request.py_batch_idx = batch_idx
batch_idx += 1
request.py_batch_idx = py_batch_idx(request)
else:
# update batch index
previous_batch_idx = request.py_batch_idx
request.py_batch_idx = batch_idx
batch_idx += 1
request.py_batch_idx = py_batch_idx(request)
# inputs
# overlap scheduler can only support the speculative decoding
# methods with a fixed number of draft tokens
@ -1324,12 +1329,21 @@ class PyTorchModelEngine(ModelEngine):
prompt_lengths.append(request.py_prompt_len)
draft_lens.append(0)
request.py_batch_idx = batch_idx
batch_idx += 1
request.py_batch_idx = py_batch_idx(request)
previous_batch_len = len(previous_batch_indices)
def previous_seq_slots_device():
previous_batch_indices_host = torch.tensor(previous_batch_indices,
dtype=torch.int,
pin_memory=True)
previous_slots = self.previous_batch_indices_cuda[:
previous_batch_len]
previous_slots.copy_(previous_batch_indices_host, non_blocking=True)
return previous_slots
num_tokens = len(input_ids)
num_draft_tokens = len(draft_tokens)
previous_batchs = len(previous_batch_indices)
num_requests = len(request_ids)
total_num_tokens = len(position_ids)
assert total_num_tokens <= self.max_num_tokens, (
@ -1347,67 +1361,55 @@ class PyTorchModelEngine(ModelEngine):
self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
non_blocking=True)
if next_draft_tokens_device is not None:
if len(previous_batch_indices) > 0:
previous_batch_indices = torch.tensor(previous_batch_indices,
dtype=torch.int,
pin_memory=True)
self.previous_batch_indices_cuda[:previous_batchs].copy_(
previous_batch_indices, non_blocking=True)
if previous_batch_len > 0:
previous_slots = previous_seq_slots_device()
# previous input ids
previous_batch_tokens = previous_batchs * (1 +
self.max_draft_len)
self.input_ids_cuda[
num_tokens:num_tokens +
previous_batch_tokens].copy_(new_tokens_device[
self.previous_batch_indices_cuda[:previous_batchs], :].
flatten(),
non_blocking=True)
previous_batch_tokens = previous_batch_len * (
1 + self.max_draft_len)
new_tokens = new_tokens_device[previous_slots, :].flatten()
self.input_ids_cuda[num_tokens:num_tokens +
previous_batch_tokens].copy_(
new_tokens, non_blocking=True)
# previous draft tokens
previous_batch_draft_tokens = previous_batchs * self.max_draft_len
self.draft_tokens_cuda[
num_draft_tokens:num_draft_tokens +
previous_batch_draft_tokens].copy_(next_draft_tokens_device[
self.previous_batch_indices_cuda[:previous_batchs], :].
flatten(),
non_blocking=True)
previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
previous_batch_draft_tokens].copy_(
next_draft_tokens_device[
previous_slots, :].flatten(),
non_blocking=True)
# prepare data for the preprocess inputs
kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
previous_pos_indices = torch.tensor(previous_pos_indices,
dtype=torch.int,
pin_memory=True)
previous_pos_indices_host = torch.tensor(previous_pos_indices,
dtype=torch.int,
pin_memory=True)
self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
previous_pos_indices, non_blocking=True)
previous_pos_indices_host, non_blocking=True)
self.previous_pos_id_offsets_cuda[
0:previous_batch_tokens].copy_(
new_tokens_lens_device[self.previous_pos_indices_cuda[
0:previous_batch_tokens]],
non_blocking=True)
self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
kv_len_offsets_device[
self.previous_batch_indices_cuda[:previous_batchs]],
non_blocking=True)
self.previous_kv_lens_offsets_cuda[0:previous_batch_len].copy_(
kv_len_offsets_device[previous_slots], non_blocking=True)
# for the requests that do not have previous batch, set the previous_pos_id_offsets and
# previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
self.previous_pos_id_offsets_cuda[
previous_batch_tokens:num_requests *
(1 + self.max_draft_len)] *= 0
self.previous_kv_lens_offsets_cuda[
previous_batchs:num_requests] *= 0
previous_batch_len:num_requests] *= 0
else:
# change the data to zeros to skip the value changes in _preprocess_inputs
self.previous_pos_id_offsets_cuda *= 0
self.previous_kv_lens_offsets_cuda *= 0
elif new_tokens_device is not None:
previous_batch_tokens = len(previous_batch_indices)
previous_batch_indices = torch.tensor(previous_batch_indices,
dtype=torch.int,
pin_memory=True)
self.previous_batch_indices_cuda[:previous_batch_tokens].copy_(
previous_batch_indices, non_blocking=True)
self.input_ids_cuda[num_tokens:num_tokens + previous_batchs].copy_(
new_tokens_device[
self.previous_batch_indices_cuda[:previous_batchs]],
non_blocking=True)
seq_slots_device = previous_seq_slots_device()
max_draft_len = max(draft_lens)
new_tokens = new_tokens_device[:max_draft_len + 1,
seq_slots_device, :self.BEAM_WIDTH]
self.input_ids_cuda[num_tokens:num_tokens +
previous_batch_len].copy_(new_tokens.flatten(),
non_blocking=True)
position_ids = torch.tensor(position_ids,
dtype=torch.int,
@ -1645,7 +1647,6 @@ class PyTorchModelEngine(ModelEngine):
# for star attention, we need customized block ids
block_ids_per_seq = []
num_cached_tokens_per_seq = []
output_token_idx = 0
for request in scheduled_requests.context_requests:
request_ids.append(request.py_request_id)
prompt_lengths.append(request.py_prompt_len)
@ -1702,8 +1703,6 @@ class PyTorchModelEngine(ModelEngine):
sequence_lengths.append(len(input_id))
block_ids_per_seq.extend([all_cache_indices])
num_cached_tokens_per_seq.append(past_seen_token_num)
request.output_token_idx = output_token_idx
output_token_idx += 1
num_contexts = len(sequence_lengths)
for request in scheduled_requests.context_requests:
ctx_iter = request.ctx_iters
@ -1743,8 +1742,6 @@ class PyTorchModelEngine(ModelEngine):
sequence_lengths.append(len(input_id))
block_ids_per_seq.extend([all_cache_indices])
num_cached_tokens_per_seq.append(past_seen_token_num)
request.output_token_idx = output_token_idx
output_token_idx += 1
num_queries = len(sequence_lengths) - num_contexts
# Requests with draft tokens are treated like extend requests.
@ -1802,8 +1799,6 @@ class PyTorchModelEngine(ModelEngine):
position_ids.append(last_query_pos_id + request.gen_iters + 1)
block_ids_per_seq.extend([all_cache_indices])
num_cached_tokens_per_seq.append(past_seen_token_num)
request.output_token_idx = output_token_idx
output_token_idx += 1
num_tokens = len(input_ids)
assert num_tokens <= self.max_num_tokens, (
@ -2171,9 +2166,7 @@ class PyTorchModelEngine(ModelEngine):
num_ctx_req = len(scheduled_requests.context_requests)
logits_tensor = outputs["logits"]
for idx, request in enumerate(
itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)):
for idx, request in enumerate(scheduled_requests.all_requests()):
logits_processors = getattr(request, "py_logits_post_processors",
None)
if not logits_processors:
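
Perf-wise, the key change in this file is the overlap-scheduler path: last iteration's tokens are now gathered straight out of the unified buffer by sequence slot instead of being tracked with a serial batch index. A reduced, self-contained sketch of that gather (slot values and shapes are assumptions; the real code uses pinned host tensors and non-blocking copies into preallocated CUDA buffers):

import torch

# Unified layout [step, seq_slot, beam]; each slot stores its own index so the
# gather is easy to see.
num_seq_slots, beam_width = 8, 1
new_tokens_device = torch.arange(num_seq_slots, dtype=torch.int).view(1, num_seq_slots, beam_width)

# Sequence slots of the generation requests from the previous iteration.
previous_slots = torch.tensor([6, 1, 4], dtype=torch.long)

max_draft_len = 0  # no speculative tokens in this sketch
gathered = new_tokens_device[:max_draft_len + 1, previous_slots, :beam_width]
next_input_ids = gathered.flatten()  # appended after the current input_ids
assert next_input_ids.tolist() == [6, 1, 4]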

View File

@ -11,12 +11,12 @@ import traceback
import weakref
from collections import namedtuple
from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
import torch
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
is_trace_enabled, nvtx_range, trace_func)
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
@ -36,7 +36,7 @@ from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
LlmResponse, executor_request_to_llm_request)
from .model_engine import ModelEngine
from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler
from .scheduler import ScheduledRequests
from .scheduler import RequestScheduler, ScheduledRequests
# Environment variable to specify iteration ranges for profiling start/stop.
# Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
@ -163,10 +163,11 @@ class PyExecutor:
def __init__(self,
resource_manager,
scheduler,
scheduler: RequestScheduler,
model_engine: ModelEngine,
sampler: Sampler,
dist: Distributed,
max_num_sequences: int,
drafter: Drafter = None,
disable_overlap_scheduler: bool = False,
max_input_len: int = 2048,
@ -271,11 +272,13 @@ class PyExecutor:
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
self.event_loop = trace_func(self.event_loop)
if self.draft_model_engine is not None and self.event_loop.__name__ != self._executor_loop.__name__:
raise NotImplementedError(
"Drafting is not supported for selected executor loop. "
"Please disable disagg/pipeline parallelism/overlap scheduler.")
if self.draft_model_engine is not None:
if self.event_loop.__name__ != self._executor_loop.__name__:
raise NotImplementedError(
"Drafting is not supported for selected executor loop. "
"Please disable disagg/pipeline parallelism/overlap scheduler."
)
self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences)
self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
self.worker_started = False
@ -757,7 +760,7 @@ class PyExecutor:
"cpu", non_blocking=True)
sample_state = self._sample_async(
scheduled_batch, batch_outputs)
sample_state.logits_host = logits_host
sample_state.host.logits = logits_host
self._update_request_states(scheduled_batch)
if self.enable_iter_perf_stats:
@ -788,7 +791,6 @@ class PyExecutor:
# Receive tokens from previous pp rank (w.r.t model forward direction)
(
logits,
sample_state.log_probs,
sample_state.host,
) = self.dist.recv_object(
src=self.dist.prev_pp_rank,
@ -796,8 +798,9 @@ class PyExecutor:
)
if logits is not None:
logits_host = torch.from_numpy(logits)
sample_state.logits_host = logits_host
sample_state.logits = logits_host.to(self.device_id)
sample_state.host.logits = logits_host
sample_state.device.logits = logits_host.to(
self.device_id)
else:
torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp")
sample_state.sampler_event.synchronize()
@ -807,16 +810,16 @@ class PyExecutor:
if not self.dist.is_second_last_pp_rank:
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].Wait()
needs_logits = (
self._need_return_logits(scheduled_batch)
or (self._need_return_log_probs(scheduled_batch)
and sample_state.host.log_probs is not None))
serialized_logits = sample_state.host.logits.numpy(
) if needs_logits else None
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
(
sample_state.logits_host.numpy() if
self._need_return_logits(scheduled_batch) or
(self._need_return_log_probs(
scheduled_batch)
and sample_state.log_probs is not None)
else None,
sample_state.log_probs,
serialized_logits,
sample_state.host,
),
dest=self.dist.next_pp_rank,
@ -1754,41 +1757,35 @@ class PyExecutor:
input_tokens = spec_config.get_draft_model_prompt(
request.get_tokens()[beam_idx])
if request.max_beam_num_tokens - 1 == request.py_prompt_len:
# This is the first time the draft model is seeing this request.
# Prepare a context request. We discard the first token and take
# the newly decoded one - this is the convention for EAGLE 2 and 3.
assert num_draft_tokens == 0
new_request = LlmRequest(
def create_new_request(input_tokens):
return LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens,
sampling_config=request.sampling_config,
return_perf_metrics=request.return_perf_metrics,
is_streaming=False)
is_streaming=False,
is_draft=True)
if request.max_beam_num_tokens - 1 == request.py_prompt_len:
# This is the first time the draft model is seeing this request.
# Prepare a context request. We discard the first token and take
# the newly decoded one - this is the convention for EAGLE 2 and 3.
assert num_draft_tokens == 0
new_request = create_new_request(input_tokens)
draft_batch.context_requests.append(new_request)
elif num_accepted_tokens == 0:
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens[:-1],
sampling_config=request.sampling_config,
return_perf_metrics=request.return_perf_metrics,
is_streaming=False)
new_request = create_new_request(input_tokens[:-1])
# Explicitly add the last token so get_last_tokens() returns
# the right value
new_request.add_new_token(input_tokens[-1], beam_idx)
new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
draft_batch.generation_requests.append(new_request)
else:
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens,
sampling_config=request.sampling_config,
return_perf_metrics=request.return_perf_metrics,
is_streaming=False)
new_request = create_new_request(input_tokens)
new_request.context_chunk_size = num_accepted_tokens + 1
new_request.context_current_position = len(
input_tokens) - num_accepted_tokens - 1
new_request.context_chunk_size = num_accepted_tokens + 1
new_request.context_current_position = len(
input_tokens) - num_accepted_tokens - 1
@ -1807,16 +1804,19 @@ class PyExecutor:
@nvtx_range("_prepare_draft_tokens")
def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
if not self.draft_model_engine:
raise ValueError("Draft model engine is not set")
try:
draft_batch = self._prepare_draft_batch(scheduled_requests)
if draft_batch.batch_size == 0:
return
self.draft_seq_slot_manager.prepare_resources(draft_batch)
req_id_to_old_request = {
req.py_request_id: req
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)
for req in scheduled_requests.all_requests()
}
# Disable cuda graph for the 1st draft model forward
@ -1838,8 +1838,7 @@ class PyExecutor:
def _process_decoded_tokens(draft_batch):
new_requests = []
for req in chain(draft_batch.context_requests,
draft_batch.generation_requests):
for req in draft_batch.all_requests():
target_model_req = req_id_to_old_request[req.py_request_id]
target_model_req.py_draft_tokens.append(
req.get_last_tokens(0))
@ -1847,6 +1846,8 @@ class PyExecutor:
target_model_req.py_draft_tokens
) < target_model_req.py_draft_pages_allocated:
new_requests.append(req)
else:
self.draft_seq_slot_manager.free_resources(req)
return new_requests
@ -2087,14 +2088,12 @@ class PyExecutor:
def _add_inflight_ids(self, scheduled_requests):
"""Add reqids of current requests to self.inflight_req_ids."""
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
for req in scheduled_requests.all_requests():
self.inflight_req_ids.insert(req.request_id)
def _remove_inflight_ids(self, scheduled_requests):
"""Remove reqids of current requests from self.inflight_req_ids."""
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
for req in scheduled_requests.all_requests():
self.inflight_req_ids.erase(req.request_id)
def _should_exclude_last_generation_logits(self) -> bool:
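
A side effect visible throughout this file: logits and log-probs are no longer loose fields on SampleState; they travel inside the host-side SampleStateTensors, so the pipeline-parallel ranks exchange one host object (plus a serialized numpy array when logits must be returned). A reduced stand-in for the dataclass, only to show what crosses rank boundaries; the real definition lives in sampler.py below:

from dataclasses import dataclass

import torch

@dataclass(kw_only=True)
class SampleStateTensorsSketch:  # reduced stand-in for SampleStateTensors
    new_tokens: torch.Tensor
    logits: torch.Tensor | None = None
    log_probs: torch.Tensor | None = None

host = SampleStateTensorsSketch(new_tokens=torch.zeros(1, 8, 1, dtype=torch.int))
host.logits = torch.randn(3, 32)  # attached only when a rank must return logits
# What isend_object would ship to the next pp rank in this scheme:
payload = (host.logits.numpy() if host.logits is not None else None, host)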

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Literal
import torch
@ -27,9 +28,11 @@ from .llm_request import LlmRequest, LlmRequestState
from .scheduler import ScheduledRequests
@dataclass(frozen=True, kw_only=True)
@dataclass(kw_only=True)
class SampleStateTensors:
new_tokens: torch.Tensor
logits: torch.Tensor | None = None
log_probs: torch.Tensor | None = None
def values(self):
return vars(self).values()
@ -39,13 +42,6 @@ class SampleStateTensors:
class SampleState:
scheduled_requests: ScheduledRequests
logits: torch.Tensor = None
logits_host: torch.Tensor = None
# Set when decode_async() has evaluated these to avoid computing again in update_requests()
# log_probs[request_idx][token_idx]
log_probs: list[list[float] | None] | None = None
device: SampleStateTensors = None
host: SampleStateTensors = None
@ -77,10 +73,12 @@ class EarlyStopSampler(Sampler):
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleState:
return SampleState(scheduled_requests=scheduled_requests,
logits=model_outputs['logits'])
host = SampleStateTensors(logits=model_outputs['logits'],
new_tokens=torch.empty(0))
return SampleState(scheduled_requests=scheduled_requests, host=host)
def update_requests(self, state: SampleState) -> None:
assert isinstance(state, SampleState)
scheduled_requests = state.scheduled_requests
assert (not scheduled_requests.generation_requests)
for idx, request in enumerate(scheduled_requests.context_requests):
@ -88,7 +86,7 @@ class EarlyStopSampler(Sampler):
# NOTE: This is a hack: set finish reason manually and set the beam 0
request.set_finished_reason(FinishReason.LENGTH, 0)
if request.py_return_context_logits:
logits = state.logits[idx]
logits = state.host.logits[idx]
if logits.ndim == 1:
# For BERT: Add axis to be compatible with LogitsStorage
# (LogitsStorage will interpret this dim as the prompt_len which
@ -104,8 +102,6 @@ def top_k_sampling_batch(logits, top_k=50):
# logits should be 2D [batch_size, vocab_size]
batch_size, vocab_size = logits.size()
raw_probs = torch.softmax(logits, dim=-1)
# get first top_k logits of each sample and their indices
values, indices = torch.topk(logits, top_k, dim=-1)
min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
@ -115,24 +111,18 @@ def top_k_sampling_batch(logits, top_k=50):
torch.full_like(logits, float('-inf')), logits)
# compute probability distribution
probs = torch.softmax(logits, dim=-1)
softmax = torch.softmax(logits, dim=-1)
# sample from the distribution and generate result of [batch_size, 1]
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
token_probs = torch.gather(raw_probs, dim=1,
index=next_tokens.unsqueeze(1)).squeeze(-1)
log_probs = torch.log(token_probs)
return next_tokens, log_probs
next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1)
return next_tokens, softmax
def top_p_sampling_batch(logits, top_p=0.9):
def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
logits_dim = logits.dim()
if logits_dim == 1:
logits = logits.unsqueeze(0)
# logits should be 2D [batch_size, vocab_size]
batch_size, vocab_size = logits.size()
raw_probs = torch.softmax(logits, dim=-1)
assert logits_dim == 2, "logits should be 2D [batch_size, vocab_size]"
# sort the logits of each sample in descending order
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
@ -152,46 +142,96 @@ def top_p_sampling_batch(logits, top_p=0.9):
logits = logits.masked_fill(indices_to_remove, float('-inf'))
# compute probability distribution
probs = torch.softmax(logits, dim=-1)
softmax = torch.softmax(logits, dim=-1)
# sample from the distribution and generate result of [batch_size, 1]
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
token_probs = torch.gather(raw_probs, dim=1,
index=next_tokens.unsqueeze(1)).squeeze(-1)
log_probs = torch.log(token_probs)
return next_tokens, log_probs
next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1)
return next_tokens, softmax
def greedy_search_sampling_batch(logits):
raw_probs = torch.softmax(logits, dim=-1)
next_tokens = torch.argmax(logits, dim=-1)
token_probs = torch.gather(raw_probs, dim=1,
index=next_tokens.unsqueeze(1)).squeeze(-1)
log_probs = torch.log(token_probs)
return next_tokens, log_probs
softmax = torch.softmax(logits, dim=-1)
return next_tokens, softmax
def decode_single_request(request: LlmRequest, logits):
assert logits.dim(
) == 2 and logits.shape[0] == 1, "logits should have shape [1, vocab_size]"
TopK = tuple[Literal["top_k"], int]
TopP = tuple[Literal["top_p"], float]
Greedy = tuple[Literal["greedy"], None]
GREEDY: Greedy = ("greedy", None)
Strategy = TopK | TopP | Greedy
def request_strategy(request: LlmRequest) -> Strategy:
if request.sampling_config.top_p is not None and len(
request.sampling_config.top_p) > 0:
next_tokens, log_probs = top_p_sampling_batch(
logits, request.sampling_config.top_p[0])
return ("top_p", request.sampling_config.top_p[0])
elif request.sampling_config.top_k is not None and len(
request.sampling_config.top_k) > 0:
next_tokens, log_probs = top_k_sampling_batch(
logits, request.sampling_config.top_k[0])
return ("top_k", request.sampling_config.top_k[0])
else:
next_tokens, log_probs = greedy_search_sampling_batch(logits)
return next_tokens, log_probs
return ("greedy", None)
def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]:
return [request_strategy(req) for req in requests]
def sample(strategy: Strategy, logits: torch.Tensor):
match strategy:
case ("top_k", top_k):
return top_k_sampling_batch(logits, top_k)
case ("top_p", top_p):
return top_p_sampling_batch(logits, top_p)
case ("greedy", None):
return greedy_search_sampling_batch(logits)
def add_token(request: LlmRequest,
new_tokens: torch.Tensor,
*,
beam: int,
step: int = 0) -> int:
seq_slot = request.seq_slot
assert seq_slot is not None
new_token = int(new_tokens[step, request.seq_slot, beam])
request.add_new_token(new_token, beam)
return new_token
class TorchSampler(Sampler):
BEAM = 0
MAX_BEAM_WIDTH = BEAM + 1
def __init__(self, max_seq_len: int, mixed_sampler: bool = False):
self.max_seq_len = max_seq_len
self.mixed_sampler = mixed_sampler
@dataclass(frozen=True, kw_only=True)
class Store:
new_tokens: torch.Tensor
"""Shape: See cpp DecoderState.getAllNewTokens()"""
@dataclass(frozen=True, kw_only=True)
class Args:
max_seq_len: int
max_draft_tokens: int
max_num_sequences: int
max_beam_width: int
mixed_sampler: bool
def __init__(self, args: Args):
self.max_seq_len = args.max_seq_len
self.mixed_sampler = args.mixed_sampler
self.max_tokens = args.max_draft_tokens + 1
assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
self.num_seq_slots = args.max_num_sequences
# AutoDeploy build creates the sampler in inference mode,
# which would disallow in-place mutating of new_tokens.
# So, we temporarily exit inference mode.
with torch.inference_mode(False):
new_tokens = torch.zeros(
(self.max_tokens, self.num_seq_slots, self.MAX_BEAM_WIDTH),
dtype=torch.int,
device='cuda')
self.store = self.Store(new_tokens=new_tokens)
def _meet_max_token_stop_criteria(self, request: LlmRequest,
num_tokens: int):
@ -199,7 +239,8 @@ class TorchSampler(Sampler):
>= request.py_max_new_tokens) or (num_tokens
>= self.max_seq_len)
def _meet_stop_token_criteria(self, request: LlmRequest):
@staticmethod
def _meet_stop_token_criteria(request: LlmRequest):
if request.py_stop_words_list:
assert isinstance(
request.py_stop_words_list,
@ -217,233 +258,197 @@ class TorchSampler(Sampler):
return True
return False
def _handle_stop_criteria(self, request: LlmRequest, new_token: int,
num_tokens: int, beam_idx: int) -> bool:
def _handle_stop_criteria(self, request: LlmRequest, new_token: int, *,
beam: int) -> bool:
"""Handle stop criteria and set appropriate finish reasons and state.
Returns True if generation should stop."""
if new_token == request.py_end_id:
request.state = LlmRequestState.GENERATION_COMPLETE
request.set_finished_reason(FinishReason.END_ID, beam_idx)
request.finish_by_reason(FinishReason.END_ID)
return True
num_tokens = request.get_num_tokens(beam)
if self._meet_max_token_stop_criteria(request, num_tokens):
request.state = LlmRequestState.GENERATION_COMPLETE
request.set_finished_reason(FinishReason.LENGTH, beam_idx)
request.finish_by_reason(FinishReason.LENGTH)
return True
if self._meet_stop_token_criteria(request):
request.state = LlmRequestState.GENERATION_COMPLETE
request.set_finished_reason(FinishReason.STOP_WORDS, beam_idx)
request.finish_by_reason(FinishReason.STOP_WORDS)
return True
return False
def update_requests(self, state: SampleState) -> None:
if state.sampler_event:
state.sampler_event.synchronize()
new_tokens_list = state.host.new_tokens.tolist()
scheduled_requests = state.scheduled_requests
request_idx = 0
token_idx = 0
beam_idx = 0
def advance_idx(num_tokens=1):
nonlocal request_idx, token_idx
request_idx += 1
token_idx += num_tokens
def handle_logits(request: LlmRequest, tokens: list[int], count=1):
if state.logits is None:
return
if not request.py_return_generation_logits and not request.py_return_log_probs:
return
current_slice = slice(token_idx, token_idx + count)
current_logits = state.logits[current_slice]
request.py_result.append_generation_logits(current_logits)
if not request.py_return_log_probs:
return
if state.log_probs:
log_probs = state.log_probs[request_idx]
else:
_, log_probs = greedy_search_sampling_batch(current_logits)
token_log_probs = [{
token: Logprob(logprob=logprob, rank=1)
} for token, logprob in zip(tokens, log_probs.tolist())]
request.py_result.append_log_probs([token_log_probs])
if hasattr(scheduled_requests, 'chunked_requests'):
request_idx += len(scheduled_requests.chunked_requests)
token_idx += len(scheduled_requests.chunked_requests)
for request in scheduled_requests.context_requests:
if request.context_remaining_length != 0:
advance_idx()
continue
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[token_idx]
num_tokens = request.add_new_token(new_token, beam_idx)
self._handle_stop_criteria(request, new_token, num_tokens,
beam_idx)
handle_logits(request, [new_token])
request.py_decoding_iter += 1
advance_idx()
extend_requests = []
generation_requests = []
for request in scheduled_requests.generation_requests:
if len(request.py_draft_tokens) > 0:
extend_requests.append(request)
else:
generation_requests.append(request)
for request in extend_requests:
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[token_idx]
num_tokens = request.add_new_token(new_token, beam_idx)
if self._handle_stop_criteria(request, new_token, num_tokens,
beam_idx):
continue
# Accept draft tokens (if we have any) if and only if they match the new
# token exactly.
num_accepted = 0
new_tokens = [new_token]
for draft_token in request.py_draft_tokens:
if draft_token != new_token:
# Reject.
break
num_accepted += 1
new_token = new_tokens_list[token_idx + num_accepted]
num_tokens = request.add_new_token(new_token, beam_idx)
new_tokens.append(num_tokens) # `num_tokens`->`new_token`
if self._handle_stop_criteria(request, new_token,
num_tokens, beam_idx):
break
handle_logits(request, new_tokens, num_accepted)
request.py_decoding_iter += 1
request.py_num_accepted_draft_tokens = num_accepted
request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
advance_idx(len(request.py_draft_tokens) + 1)
for request in generation_requests:
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[token_idx]
num_tokens = request.add_new_token(new_token, beam_idx)
self._handle_stop_criteria(request, new_token, num_tokens,
beam_idx)
handle_logits(request, [new_token])
request.py_decoding_iter += 1
advance_idx()
def _mixed_sample(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleState:
logits = model_outputs["logits"]
log_probs = []
new_tokens_device_array = []
idx = 0
for request in scheduled_requests.context_requests:
assert not request.py_return_context_logits, "Return context logits not supported"
token_logits = logits[idx:idx + 1, :]
new_token, probs = decode_single_request(request, token_logits)
new_tokens_device_array.append(new_token)
probs = [probs.tolist()] if request.py_return_log_probs else None
log_probs.append(probs) # Currently always beam_width=1
idx += 1
for request in scheduled_requests.generation_requests:
if request.state == LlmRequestState.GENERATION_COMPLETE:
continue
assert len(
request.py_draft_tokens
) == 0, "Speculative decoding not supported in SeparateDecoder."
token_logits = logits[idx:idx + 1, :]
new_token, probs = decode_single_request(request, token_logits)
new_tokens_device_array.append(new_token)
probs = [probs.tolist()] if request.py_return_log_probs else None
log_probs.append(probs) # Currently always beam_width=1
idx += 1
new_tokens_device = torch.cat(new_tokens_device_array)
new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
sampler_event = torch.cuda.Event()
sampler_event.record()
return SampleState(
scheduled_requests=scheduled_requests,
logits=logits,
device=SampleStateTensors(new_tokens=new_tokens_device),
host=SampleStateTensors(new_tokens=new_tokens_host),
sampler_event=sampler_event,
log_probs=log_probs)
def _batch_sample(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleState:
logits = model_outputs["logits"]
new_tokens_device = torch.argmax(logits, dim=-1)
new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
sampler_event = torch.cuda.Event()
sampler_event.record()
return SampleState(
scheduled_requests=scheduled_requests,
logits=logits,
device=SampleStateTensors(new_tokens=new_tokens_device),
host=SampleStateTensors(new_tokens=new_tokens_host),
sampler_event=sampler_event)
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleState:
if self.mixed_sampler:
return self._mixed_sample(scheduled_requests, model_outputs)
else:
return self._batch_sample(scheduled_requests, model_outputs)
class TorchStarAttentionSampler(TorchSampler):
def update_one_request(self, request: LlmRequest,
new_tokens_list: list[int], logits: torch.Tensor):
beam_idx = 0
output_token_idx = request.output_token_idx
new_token = new_tokens_list[output_token_idx]
num_tokens = request.add_new_token(new_token, beam_idx)
current_logits = logits[output_token_idx].unsqueeze(0)
def handle_logits(self, request: LlmRequest, state: SampleState, *,
beam: int, count: int):
current_slice = slice(0, count), request.seq_slot, beam
if request.py_return_generation_logits:
assert state.host.logits is not None
current_logits = state.host.logits[current_slice]
request.py_result.append_generation_logits(current_logits)
if request.py_return_log_probs:
_, log_probs = greedy_search_sampling_batch(current_logits)
request.py_result.append_log_probs([[{
new_token:
Logprob(logprob=log_probs.item(), rank=1)
}]])
assert state.host.log_probs is not None
log_probs = state.host.log_probs[request.seq_slot][beam][:count]
current_tokens = state.host.new_tokens[current_slice]
self._handle_stop_criteria(request, new_token, num_tokens, beam_idx)
if request.state != LlmRequestState.GENERATION_COMPLETE:
request.py_decoding_iter += 1
token_log_probs = [{
int(token): Logprob(logprob=logprob, rank=1)
} for token, logprob in zip(current_tokens, log_probs.tolist())]
assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
request.py_result.append_log_probs([token_log_probs])
def update_requests(self, state: SampleState):
def process_draft_tokens(self, request: LlmRequest,
new_tokens: torch.Tensor, new_token: int) -> int:
num_accepted = 0
for draft_token in request.py_draft_tokens:
if draft_token != new_token:
# Reject.
break
num_accepted += 1
new_token = add_token(request,
new_tokens,
beam=self.BEAM,
step=num_accepted)
if self._handle_stop_criteria(request, new_token, beam=self.BEAM):
break
return num_accepted
def update_requests(self, state: SampleState) -> None:
assert isinstance(state, SampleState)
if state.sampler_event:
state.sampler_event.synchronize()
new_tokens_list = state.host.new_tokens.tolist()
logits = state.logits
new_tokens = state.host.new_tokens
for request in state.scheduled_requests.context_requests:
if request.state == LlmRequestState.GENERATION_IN_PROGRESS:
self.update_one_request(request, new_tokens_list, logits)
for req in state.scheduled_requests.context_requests:
if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
continue
new_token = add_token(req, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM)
self.handle_logits(req, state, beam=self.BEAM, count=1)
req.py_decoding_iter += 1
for request in state.scheduled_requests.generation_requests:
self.update_one_request(request, new_tokens_list, logits)
for req in state.scheduled_requests.generation_requests:
if req.state == LlmRequestState.GENERATION_COMPLETE:
continue
new_token = add_token(req, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM)
processed = 1
if not stop and len(req.py_draft_tokens) > 0:
num_accepted = self.process_draft_tokens(
req, new_tokens, new_token)
req.py_num_accepted_draft_tokens = num_accepted
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
processed += num_accepted
self.handle_logits(req, state, beam=self.BEAM, count=processed)
req.py_decoding_iter += 1
def log_probs_host(self, requests: Iterable[LlmRequest]):
"""Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
if any(req.py_return_log_probs for req in requests):
return torch.empty(
(self.num_seq_slots, self.MAX_BEAM_WIDTH, self.max_tokens),
device="cpu",
pin_memory=True)
return None
def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int):
if any(req.py_return_generation_logits for req in requests):
return torch.empty((self.max_tokens, self.num_seq_slots,
self.MAX_BEAM_WIDTH, vocab_size),
device="cpu",
pin_memory=True)
return None
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs: dict[str, torch.Tensor]) -> SampleState:
requests = scheduled_requests.all_requests()
new_tokens = self.store.new_tokens
vocab_size = model_outputs["logits"].shape[-1]
log_probs_host = self.log_probs_host(requests)
gen_logits_host = self.gen_logits_host(requests, vocab_size)
self._process_requests(requests,
model_outputs,
new_tokens,
gen_logits_host=gen_logits_host,
log_probs_host=log_probs_host)
new_tokens_host = new_tokens.to(device="cpu", non_blocking=True)
sampler_event = torch.cuda.Event()
sampler_event.record()
return SampleState(scheduled_requests=scheduled_requests,
device=SampleStateTensors(new_tokens=new_tokens),
host=SampleStateTensors(new_tokens=new_tokens_host,
log_probs=log_probs_host,
logits=gen_logits_host),
sampler_event=sampler_event)
@staticmethod
def append_eagle3(tokens: torch.Tensor, model_outputs):
if "d2t" in model_outputs:
d2t = model_outputs["d2t"][tokens]
tokens += d2t
def _process_requests(self,
requests: list[LlmRequest],
model_outputs: dict[str, torch.Tensor],
new_tokens: torch.Tensor,
*,
gen_logits_host: torch.Tensor | None = None,
log_probs_host: torch.Tensor | None = None):
beam_width = self.MAX_BEAM_WIDTH
beam = self.BEAM
raw_logits = model_outputs["logits"]
num_steps = [1 + len(req.py_draft_tokens) for req in requests]
sum_steps = sum(num_steps)
no_draft_tokens = len(requests) == sum_steps
fast_path = not self.mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
seq_slots = torch.as_tensor([r.seq_slot for r in requests])
seq_slots = seq_slots.to(device="cuda", non_blocking=True)
if fast_path:
logits = raw_logits[:len(requests)]
next_tokens = torch.argmax(logits, dim=-1)
self.append_eagle3(next_tokens, model_outputs)
int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
next_tokens = int_next_tokens.view(1, -1, beam_width)
new_tokens[:1].index_copy_(1, seq_slots, next_tokens)
return
strategies = sampling_strategies(requests)
batched_next_tokens, batched_softmax = None, None
batched_strategy: Strategy | None = GREEDY
if self.mixed_sampler:
assert "d2t" not in model_outputs, "eagle3 does not yet support non-greedy sampling"
if len(set(strategies)) == 1:
batched_strategy = strategies[0]
else:
batched_strategy = None
if batched_strategy is not None:
logits = raw_logits[:sum_steps]
batched_next_tokens, batched_softmax = sample(
batched_strategy, logits)
self.append_eagle3(batched_next_tokens, model_outputs)
offset = 0
for strategy, slot, steps in zip(strategies, seq_slots, num_steps):
input_slice = slice(offset, offset + steps)
logits = raw_logits[input_slice]
if batched_next_tokens is None:
next_tokens, softmax = sample(strategy, logits)
else:
next_tokens = batched_next_tokens[input_slice]
softmax = batched_softmax[input_slice]
current_slice = slice(0, steps), slot, beam
new_tokens[current_slice] = next_tokens
if gen_logits_host is not None:
gen_logits_host[current_slice].copy_(logits, non_blocking=True)
if log_probs_host is not None:
assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
token_probs = torch.gather(
softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1)
log_probs = torch.log(token_probs)
log_probs_host[slot, beam, :steps].copy_(log_probs,
non_blocking=True)
offset += steps
class Algorithms:
@ -456,19 +461,17 @@ class Algorithms:
return f"Algs({', '.join(algs)})"
@dataclass(frozen=True, kw_only=True)
@dataclass(kw_only=True)
class SampleStateTensorsHostTRTLLM(SampleStateTensors):
finished_sum: torch.Tensor
finish_reasons: torch.Tensor
sequence_lengths: torch.Tensor
log_probs: torch.Tensor
cum_log_probs: torch.Tensor
cum_log_probs: torch.Tensor | None = None
@dataclass(kw_only=True)
class SampleStateTRTLLM(SampleState):
host: SampleStateTensorsHostTRTLLM
device: SampleStateTensors
class TRTLLMSampler(Sampler):
@ -527,13 +530,6 @@ class TRTLLMSampler(Sampler):
DecoderInputBuffers(self.max_num_sequences,
self.executor_config.max_batch_size,
self.MAX_DECODING_TOKENS, buffer_manager),
"new_tokens_device_tensor":
torch.empty((
self.executor_config.max_batch_size,
self.executor_config.max_beam_width,
),
dtype=torch.int,
device='cuda'),
"sequence_lengths_host":
torch.empty((
self.executor_config.max_batch_size,
@ -602,7 +598,6 @@ class TRTLLMSampler(Sampler):
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleStateTRTLLM:
batch_size = scheduled_requests.batch_size
beam_width = self.beam_width(scheduled_requests.all_requests)
self.setup_sampler_step(scheduled_requests.context_requests)
@ -631,20 +626,6 @@ class TRTLLMSampler(Sampler):
self.algs.decoder.forward_async(self.store["decoder_state"],
decoding_input)
# NOTE: The following code prepares a new_tokens_device_tensor in accordance with the
# current implementation of model_engine.
# TODO: When we support speculative decoding:
# new_tokens_device_tensor should be, for speculative decoding cases: [batch, 1 + draft_len], others: [batch]
new_tokens_device_tensor = self.store[
"new_tokens_device_tensor"][:batch_size, :beam_width]
seq_slots = [
request.seq_slot for request in scheduled_requests.all_requests
]
new_tokens_device_tensor.copy_(
self.store["decoder_state"].all_new_tokens[0][seq_slots],
non_blocking=True)
new_tokens_device_tensor = new_tokens_device_tensor.view(-1)
new_output_tokens = self.store["decoder_state"].all_new_tokens.to(
'cpu', non_blocking=True)
finished_sum = self.store["decoder_state"].finished_sum.to(
@ -654,16 +635,17 @@ class TRTLLMSampler(Sampler):
sequence_lengths = self.store["decoder_state"].sequence_lengths.to(
'cpu', non_blocking=True)
log_probs = torch.empty([0], dtype=torch.float, device='cpu')
cum_log_probs = torch.empty([0], dtype=torch.float, device='cpu')
log_probs = None
cum_log_probs = None
if any(request.py_return_log_probs
for request in scheduled_requests.all_requests):
for request in scheduled_requests.all_requests()):
log_probs = self.store["decoder_state"].log_probs.to(
'cpu', non_blocking=True)
cum_log_probs = self.store["decoder_state"].cum_log_probs.to(
'cpu', non_blocking=True)
device = SampleStateTensors(new_tokens=new_tokens_device_tensor)
device = SampleStateTensors(
new_tokens=self.store["decoder_state"].all_new_tokens)
host = SampleStateTensorsHostTRTLLM(new_tokens=new_output_tokens,
finished_sum=finished_sum,
@ -676,7 +658,6 @@ class TRTLLMSampler(Sampler):
sampler_event.record()
return SampleStateTRTLLM(scheduled_requests=scheduled_requests,
logits=model_outputs["logits"],
device=device,
host=host,
sampler_event=sampler_event)
@ -687,7 +668,8 @@ class TRTLLMSampler(Sampler):
scheduled_requests = state.scheduled_requests
assert scheduled_requests.batch_size > 0
beam_width = self.beam_width(scheduled_requests.all_requests)
requests = scheduled_requests.all_requests()
beam_width = self.beam_width(requests)
sampler_event = state.sampler_event
if sampler_event:
@ -697,7 +679,7 @@ class TRTLLMSampler(Sampler):
finished_sum_host = state.host.finished_sum
sequence_lengths_host_data = state.host.sequence_lengths
for request in scheduled_requests.all_requests:
for request in requests:
if request.is_context_init_state:
continue
@ -718,17 +700,20 @@ class TRTLLMSampler(Sampler):
seq_len - request.get_num_tokens(beam))
for step in range(num_new_tokens[beam]):
new_token = new_tokens_host[step][seq_slot][beam]
request.add_new_token(new_token, beam)
new_token = add_token(request,
new_tokens_host,
beam=beam,
step=step)
if request.py_return_log_probs:
assert state.host.log_probs is not None
# NOTE: Log probs with drafting has not been tested yet.
begin_log_probs_offset = request.prompt_len if request.sampling_config.beam_width == 1 else 0
current_token = seq_len - request.prompt_len - num_new_tokens[
beam] + step
log_probs.append({
new_token.item():
new_token:
Logprob(logprob=state.host.log_probs[seq_slot][beam]
[begin_log_probs_offset +
current_token].item(),

View File

@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from collections import namedtuple
from itertools import chain
from typing import Optional
from tensorrt_llm.bindings import executor as tb_executor
@ -36,9 +35,8 @@ class ScheduledRequests:
def batch_size(self) -> int:
return len(self.context_requests) + len(self.generation_requests)
@property
def all_requests(self) -> chain[LlmRequest]:
return chain(self.context_requests, self.generation_requests)
def all_requests(self) -> list[LlmRequest]:
return self.context_requests + self.generation_requests
class RequestScheduler(ABC):
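
all_requests changes from a property returning an itertools.chain into a method returning a list: a chain is a one-shot iterator with no len(), while the samplers now walk the request list more than once and size tensors from the request count. A quick illustration with plain strings standing in for LlmRequest objects:

from itertools import chain

context_requests = ["ctx0", "ctx1"]
generation_requests = ["gen0"]

# Old: a chain can be consumed only once and has no len().
reqs = chain(context_requests, generation_requests)
first_pass = list(reqs)   # ['ctx0', 'ctx1', 'gen0']
second_pass = list(reqs)  # [] -- the iterator is already exhausted
# len(reqs) would raise TypeError

# New: a plain list can be re-iterated, indexed and measured.
reqs = context_requests + generation_requests
assert len(reqs) == 3 and list(reqs) == first_pass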

View File

@ -1,5 +1,3 @@
import itertools
from .llm_request import LlmRequest
from .resource_manager import BaseResourceManager, SlotManager
from .scheduler import ScheduledRequests
@ -17,10 +15,8 @@ class SeqSlotManager(BaseResourceManager):
return 1
def prepare_resources(self, scheduled_batch: ScheduledRequests) -> None:
for llm_req in itertools.chain(scheduled_batch.context_requests,
scheduled_batch.generation_requests):
if (llm_req.is_context_init_state and llm_req.seq_slot is None) or \
llm_req.is_disagg_generation_transmission_complete:
for llm_req in scheduled_batch.all_requests():
if llm_req.seq_slot is None or llm_req.is_disagg_generation_transmission_complete:
llm_req.seq_slot = self.slot_manager.add_slot(
llm_req.request_id)
if llm_req.return_perf_metrics:

View File

@ -10,7 +10,7 @@ from tensorrt_llm.mapping import Mapping
from ..attention_backend import AttentionMetadata
from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
from ..pyexecutor.sampler import TorchSampler
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
from .mtp import MTPSampler
@ -214,26 +214,6 @@ class Eagle3SpecMetadata(SpecMetadata):
return hidden_states
class Eagle3Sampler(TorchSampler):
def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
logits = model_outputs["logits"]
new_tokens_device = torch.argmax(logits, dim=-1)
if "d2t" in model_outputs:
d2t = model_outputs["d2t"]
new_tokens_device = d2t[new_tokens_device] + new_tokens_device
device = SampleStateTensors(new_tokens=new_tokens_device)
host = SampleStateTensors(
new_tokens=new_tokens_device.to('cpu', non_blocking=True))
sampler_event = torch.cuda.Event()
sampler_event.record()
return SampleState(scheduled_requests=scheduled_requests,
logits=logits,
device=device,
host=host,
sampler_event=sampler_event)
@dataclass
class Eagle3OneModelSpecMetadata(SpecMetadata):
# The hidden states
@ -299,31 +279,10 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
break
class Eagle3Decoder(TorchSampler):
class Eagle3OneModelSampler(MTPSampler):
def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
logits = model_outputs["logits"]
new_tokens_device = torch.argmax(logits, dim=-1)
if "d2t" in model_outputs:
d2t = model_outputs["d2t"]
new_tokens_device = d2t[new_tokens_device] + new_tokens_device
new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
new_tensors_device = {"new_tokens_device": new_tokens_device}
new_tensors_host = {"new_tokens_host": new_tokens_host}
decoder_event = torch.cuda.Event()
decoder_event.record()
return SampleState(scheduled_requests=scheduled_requests,
logits=logits,
new_tensors_device=new_tensors_device,
new_tensors_host=new_tensors_host,
decoder_event=decoder_event)
class Eagle3OneModelDecoder(MTPSampler):
def __init__(self, max_seq_len: int, config: Eagle3Config):
super().__init__(max_seq_len, None)
self.draft_len = config.max_draft_tokens
def __init__(self, args: TorchSampler.Args):
super().__init__(args, nextn=args.max_draft_tokens)
class Eagle3OneModelWorker(nn.Module):

View File

@ -14,7 +14,7 @@ from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
@dataclass(frozen=True, kw_only=True)
@dataclass(kw_only=True)
class SampleStateTensorsMTP(SampleStateTensors):
new_tokens_lens: torch.Tensor
next_draft_tokens: torch.Tensor
@ -248,12 +248,10 @@ class MTPSampler(TorchSampler):
SampleState = SampleStateMTP
def __init__(self, max_seq_len: int, config: MTPConfig):
super().__init__(max_seq_len, False)
def __init__(self, args: TorchSampler.Args, *, nextn: int):
super().__init__(args)
self.mapping = None
self.draft_len = 0
if config is not None:
self.draft_len = config.num_nextn_predict_layers
self.draft_len = nextn
def _draft_meet_max_token_stop_criteria(self, request: LlmRequest,
num_tokens: int, beam_idx: int):
@ -283,8 +281,9 @@ class MTPSampler(TorchSampler):
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[idx][0]
num_tokens = request.add_new_token(new_token, beam_idx)
should_stop = self._handle_stop_criteria(
request, new_token, num_tokens, beam_idx)
should_stop = self._handle_stop_criteria(request,
new_token,
beam=beam_idx)
if self._draft_meet_max_token_stop_criteria(
request, num_tokens, beam_idx):
should_stop = True
@ -303,8 +302,9 @@ class MTPSampler(TorchSampler):
for i in range(num_new_tokens):
new_token = new_tokens[i]
num_tokens = request.add_new_token(new_token, beam_idx)
should_stop = self._handle_stop_criteria(
request, new_token, num_tokens, beam_idx)
should_stop = self._handle_stop_criteria(request,
new_token,
beam=beam_idx)
if should_stop:
break
if self._draft_meet_max_token_stop_criteria(
@ -344,7 +344,6 @@ class MTPSampler(TorchSampler):
for request in scheduled_requests.context_requests:
request.py_draft_tokens = [1] * self.draft_len
return SampleStateMTP(scheduled_requests=scheduled_requests,
logits=model_outputs['logits'],
device=device,
host=host,
sampler_event=sampler_event)

View File

@ -1,6 +1,9 @@
from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
from tensorrt_llm._torch.speculative.interface import SpecConfig
from .draft_target import DraftTargetSpecMetadata
from .eagle3 import (Eagle3OneModelDecoder, Eagle3OneModelSpecMetadata,
Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3Sampler,
from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
Eagle3OneModelWorker, Eagle3ResourceManager,
Eagle3SpecMetadata)
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
MTPSpecMetadata, MTPWorker)
@ -97,14 +100,17 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
return None
def get_spec_decoder(max_seq_len, spec_config):
def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig):
if spec_config.spec_dec_mode.is_mtp():
return MTPSampler(max_seq_len, spec_config)
return MTPSampler(sampler_args,
nextn=spec_config.num_nextn_predict_layers)
if spec_config.spec_dec_mode.is_eagle3():
return Eagle3Sampler(max_seq_len)
# TorchSampler handles Eagle3 gracefully, by integrating d2t into the sampling process
return TorchSampler(sampler_args)
if spec_config.spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelDecoder(max_seq_len, spec_config)
return None
return Eagle3OneModelSampler(sampler_args)
raise ValueError(
f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
def get_spec_drafter(model_engine, spec_resource_manager=None):